diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index db5c57821..7db6b4918 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4566,6 +4566,42 @@
}
}
},
+ "/v1/toolgroups/{toolgroup_id}/refresh": {
+ "post": {
+ "responses": {
+ "200": {
+ "description": "OK"
+ },
+ "400": {
+ "$ref": "#/components/responses/BadRequest400"
+ },
+ "429": {
+ "$ref": "#/components/responses/TooManyRequests429"
+ },
+ "500": {
+ "$ref": "#/components/responses/InternalServerError500"
+ },
+ "default": {
+ "$ref": "#/components/responses/DefaultError"
+ }
+ },
+ "tags": [
+ "ToolGroups"
+ ],
+ "description": "Refresh tools for a specific toolgroup.",
+ "parameters": [
+ {
+ "name": "toolgroup_id",
+ "in": "path",
+ "description": "The ID of the toolgroup to refresh tools for.",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
+ }
+ },
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
"post": {
"responses": {
@@ -5588,6 +5624,9 @@
"additionalProperties": {
"$ref": "#/components/schemas/ToolParamDefinition"
}
+ },
+ "toolgroup_name": {
+ "type": "string"
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 29ba9dede..25475dddb 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -3219,6 +3219,32 @@ paths:
schema:
$ref: '#/components/schemas/QueryTracesRequest'
required: true
+ /v1/toolgroups/{toolgroup_id}/refresh:
+ post:
+ responses:
+ '200':
+ description: OK
+ '400':
+ $ref: '#/components/responses/BadRequest400'
+ '429':
+ $ref: >-
+ #/components/responses/TooManyRequests429
+ '500':
+ $ref: >-
+ #/components/responses/InternalServerError500
+ default:
+ $ref: '#/components/responses/DefaultError'
+ tags:
+ - ToolGroups
+ description: Refresh tools for a specific toolgroup.
+ parameters:
+ - name: toolgroup_id
+ in: path
+ description: >-
+ The ID of the toolgroup to refresh tools for.
+ required: true
+ schema:
+ type: string
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
post:
responses:
@@ -3940,6 +3966,8 @@ components:
type: object
additionalProperties:
$ref: '#/components/schemas/ToolParamDefinition'
+ toolgroup_name:
+ type: string
additionalProperties: false
required:
- tool_name
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index 7d1eeeefb..668511486 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -145,6 +145,17 @@ class ToolGroups(Protocol):
"""
...
+ # NOTE: the summary line and the :param text of the docstring below appear verbatim as the
+ # endpoint and parameter descriptions in docs/_static/llama-stack-spec.html and
+ # docs/_static/llama-stack-spec.yaml; keep the generated specs in sync when editing them.
+ # The endpoint is POST-only and declares no response payload (the method returns None).
+ @webmethod(route="/toolgroups/{toolgroup_id:path}/refresh", method="POST")
+ async def refresh_tools(
+ self,
+ toolgroup_id: str,
+ ) -> None:
+ """Refresh tools for a specific toolgroup.
+
+ :param toolgroup_id: The ID of the toolgroup to refresh tools for.
+ """
+ ...
+
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
async def unregister_toolgroup(
self,
diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py
index b86f057bd..aad406a34 100644
--- a/llama_stack/distribution/routing_tables/toolgroups.py
+++ b/llama_stack/distribution/routing_tables/toolgroups.py
@@ -43,6 +43,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
+ logger.debug(f"Listing tools for toolgroup_id: {toolgroup_id}")
+
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
@@ -53,33 +55,59 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
all_tools = []
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
+ logger.debug(f"Toolgroup {toolgroup.identifier} not in cache, indexing...")
await self._index_tools(toolgroup)
- all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
+ cached_tools = self.toolgroups_to_tools.get(toolgroup.identifier, [])
+ logger.debug(f"Found {len(cached_tools)} cached tools for toolgroup {toolgroup.identifier}")
+ all_tools.extend(cached_tools)
+
+ logger.debug(f"Returning {len(all_tools)} total tools")
return ListToolsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
- provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
- tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
+ """Fetch the toolgroup's tool definitions from its provider and cache them.
+
+ On success, stores the resulting ``Tool`` records in
+ ``self.toolgroups_to_tools[toolgroup.identifier]`` and points every tool identifier
+ at the toolgroup in ``self.tool_to_toolgroup``. Any exception is logged and swallowed:
+ the toolgroup is then cached with an empty tool list and nothing is re-raised.
+
+ :param toolgroup: The toolgroup whose tools should be indexed.
+ """
+ try:
+ provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
+ logger.debug(f"Indexing tools for toolgroup {toolgroup.identifier} with provider {toolgroup.provider_id}")
- # TODO: kill this Tool vs ToolDef distinction
- tooldefs = tooldefs_response.data
- tools = []
- for t in tooldefs:
- tools.append(
- Tool(
- identifier=t.name,
- toolgroup_id=toolgroup.identifier,
- description=t.description or "",
- parameters=t.parameters or [],
- metadata=t.metadata,
- provider_id=toolgroup.provider_id,
+ if toolgroup.mcp_endpoint:
+ logger.debug(f"Toolgroup {toolgroup.identifier} has MCP endpoint: {toolgroup.mcp_endpoint.uri}")
+
+ tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
+
+ # TODO: kill this Tool vs ToolDef distinction
+ tooldefs = tooldefs_response.data
+ tools = []
+ for t in tooldefs:
+ tools.append(
+ Tool(
+ identifier=t.name,
+ toolgroup_id=toolgroup.identifier,
+ description=t.description or "",
+ parameters=t.parameters or [],
+ metadata=t.metadata,
+ provider_id=toolgroup.provider_id,
+ )
)
- )
- self.toolgroups_to_tools[toolgroup.identifier] = tools
- for tool in tools:
- self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
+ # Forward cache: toolgroup identifier -> its Tool records.
+ self.toolgroups_to_tools[toolgroup.identifier] = tools
+ # Reverse cache: tool identifier -> owning toolgroup; a later toolgroup exposing the
+ # same tool identifier overwrites the earlier entry.
+ for tool in tools:
+ self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
+
+ logger.info(f"Successfully indexed {len(tools)} tools for toolgroup {toolgroup.identifier}")
+
+ except Exception as e:
+ logger.warning(f"Failed to index tools for toolgroup {toolgroup.identifier}: {e}")
+ # Don't let tool indexing failures crash the system
+ # Initialize empty tools list so the toolgroup still exists
+ # NOTE(review): list_tools() only calls _index_tools() for identifiers missing from
+ # toolgroups_to_tools, so once [] is cached here list_tools() will not retry this
+ # toolgroup; refresh_tools() is the path that evicts the entry and re-indexes it.
+ self.toolgroups_to_tools[toolgroup.identifier] = []
+ if toolgroup.mcp_endpoint:
+ logger.info(
+ f"Toolgroup {toolgroup.identifier} has MCP endpoint - tools may be available after authentication"
+ )
+ else:
+ # NOTE(review): the same failure was already logged at WARNING above, so a
+ # non-MCP failure produces two log records (neither carries a traceback).
+ logger.error(f"Non-MCP toolgroup {toolgroup.identifier} failed to index tools: {e}")
+ # Don't raise - we want the system to continue running even if tool indexing fails
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
@@ -119,7 +147,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
# the tools should first list the tools and then use them. but there are assumptions
# baked in some of the code and tests right now.
if not toolgroup.mcp_endpoint:
- await self._index_tools(toolgroup)
+ try:
+ await self._index_tools(toolgroup)
+ except Exception as e:
+ logger.error(f"Failed to index tools during toolgroup registration for {toolgroup_id}: {e}")
+ # Don't fail the registration - the toolgroup can still be used
+ # Tools may become available later
return toolgroup
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
@@ -128,5 +161,23 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
raise ValueError(f"Tool group {toolgroup_id} not found")
await self.unregister_object(tool_group)
+    async def refresh_tools(self, toolgroup_id: str) -> None:
+        """Evict the cached tools of a toolgroup and index them again.
+
+        Useful for re-indexing after auth becomes available. This is best-effort:
+        any exception raised while looking up or re-indexing the toolgroup is logged
+        as a warning and not propagated to the caller.
+
+        :param toolgroup_id: The ID of the toolgroup to refresh tools for.
+        """
+        try:
+            toolgroup = await self.get_tool_group(toolgroup_id)
+            # _index_tools() stores its results under toolgroup.identifier, so evict
+            # under that same key rather than under the caller-supplied string.
+            group_key = toolgroup.identifier
+
+            # Clear existing tools for this toolgroup
+            stale_tools = self.toolgroups_to_tools.pop(group_key, [])
+            for tool in stale_tools:
+                # Only drop reverse entries that still point at this toolgroup: a tool
+                # with the same identifier may since have been indexed for another
+                # toolgroup, and that mapping must survive this refresh.
+                if self.tool_to_toolgroup.get(tool.identifier) == group_key:
+                    del self.tool_to_toolgroup[tool.identifier]
+
+            # Re-index tools for this toolgroup
+            await self._index_tools(toolgroup)
+        except Exception as e:
+            # Log error but don't fail - tools may become available later
+            logger.warning(f"Failed to refresh tools for toolgroup {toolgroup_id}: {e}")
+
async def shutdown(self) -> None:
pass
diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py
index 7f1ebed55..db45ec49d 100644
--- a/llama_stack/models/llama/datatypes.py
+++ b/llama_stack/models/llama/datatypes.py
@@ -99,6 +99,7 @@ class ToolDefinition(BaseModel):
tool_name: BuiltinTool | str
description: str | None = None
parameters: dict[str, ToolParamDefinition] | None = None
+ toolgroup_name: str | None = None
@field_validator("tool_name", mode="before")
@classmethod
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 4d2b9f8bf..6d156847c 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -185,6 +185,9 @@ class ChatAgent(ShieldRunnerMixin):
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
+ # Refresh tools for MCP toolgroups in case auth is now available
+ await self._refresh_mcp_tools(request.toolgroups)
+
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
yield chunk
@@ -762,6 +765,35 @@ class ChatAgent(ShieldRunnerMixin):
yield client_message
return
+    async def _refresh_mcp_tools(
+        self,
+        toolgroups_for_turn: list[AgentToolGroup] | None = None,
+    ) -> None:
+        """Refresh MCP tools in case authentication is now available.
+
+        Looks up every toolgroup used by the turn and, for those that have an MCP
+        endpoint, asks the tool groups API to refresh their tools. Failures are logged
+        and never raised, so a refresh problem cannot abort the turn.
+
+        :param toolgroups_for_turn: Toolgroups requested for this turn; when empty or
+            None, the agent config's toolgroups are used instead.
+        """
+        # Determine which tools to refresh
+        tool_groups_to_refresh = toolgroups_for_turn or self.agent_config.toolgroups or []
+
+        logger.debug(f"Refreshing MCP tools for {len(tool_groups_to_refresh)} toolgroups")
+
+        # An entry can carry an optional tool name next to the toolgroup name, so several
+        # entries may resolve to the same toolgroup. Each refresh is a lookup plus a
+        # re-index, so handle every toolgroup at most once per call.
+        already_handled: set[str] = set()
+
+        for toolgroup in tool_groups_to_refresh:
+            name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
+            toolgroup_name, _ = self._parse_toolgroup_name(name)
+            if toolgroup_name in already_handled:
+                continue
+            already_handled.add(toolgroup_name)
+
+            # Only refresh MCP toolgroups (those with endpoints)
+            try:
+                tool_group = await self.tool_groups_api.get_tool_group(toolgroup_name)
+                if tool_group.mcp_endpoint:
+                    logger.debug(
+                        f"Refreshing tools for MCP toolgroup {toolgroup_name} with endpoint {tool_group.mcp_endpoint.uri}"
+                    )
+                    # Refresh tools for this MCP toolgroup
+                    await self.tool_groups_api.refresh_tools(toolgroup_name)
+                else:
+                    logger.debug(f"Toolgroup {toolgroup_name} is not an MCP toolgroup, skipping refresh")
+            except Exception as e:
+                # Log but don't fail - tools may become available later
+                logger.warning(f"Failed to refresh MCP tools for toolgroup {toolgroup_name}: {e}")
+
async def _initialize_tools(
self,
toolgroups_for_turn: list[AgentToolGroup] | None = None,
@@ -800,6 +832,7 @@ class ChatAgent(ShieldRunnerMixin):
)
for param in tool_def.parameters
},
+ toolgroup_name="client_tools",
)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
@@ -843,6 +876,7 @@ class ChatAgent(ShieldRunnerMixin):
)
for param in tool_def.parameters
},
+ toolgroup_name=toolgroup_name,
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
index a9b252dfe..195454864 100644
--- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -15,6 +15,7 @@ from llama_stack.apis.tools import (
ToolInvocationResult,
ToolRuntime,
)
+from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
@@ -44,14 +45,45 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
+
+ logger.debug(f"Listing runtime tools for toolgroup {tool_group_id} at endpoint {mcp_endpoint.uri}")
headers = await self.get_headers_from_request(mcp_endpoint.uri)
- return await list_mcp_tools(mcp_endpoint.uri, headers)
+
+ if headers:
+ logger.debug(f"Found {len(headers)} headers for MCP endpoint {mcp_endpoint.uri}")
+ # Log header keys but not values for security
+ header_keys = list(headers.keys())
+ logger.debug(f"Header keys: {header_keys}")
+ else:
+ logger.debug(f"No headers found for MCP endpoint {mcp_endpoint.uri}")
+
+ try:
+ result = await list_mcp_tools(mcp_endpoint.uri, headers)
+ logger.info(f"Successfully listed {len(result.data)} tools for toolgroup {tool_group_id}")
+ return result
+ except AuthenticationRequiredError as e:
+ logger.warning(f"Authentication required for MCP endpoint {mcp_endpoint.uri}: {e}")
+ logger.info(
+ f"Returning empty tool list for toolgroup {tool_group_id} - tools will be refreshed when authentication is available"
+ )
+ # Return empty list on authentication errors during startup
+ # Tools will be refreshed when a turn is created with proper auth
+ return ListToolDefsResponse(data=[])
+ except Exception as e:
+ logger.error(f"Failed to list tools for toolgroup {tool_group_id} at endpoint {mcp_endpoint.uri}: {e}")
+ # Return empty list on other errors too to prevent crashes
+ return ListToolDefsResponse(data=[])
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
+ if self.tool_store is None:
+ raise ValueError(f"Tool store is not available for tool {tool_name}")
+
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
endpoint = tool.metadata.get("endpoint")
+ if endpoint is None:
+ raise ValueError(f"Tool {tool_name} does not have an endpoint")
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")