From 76f593143b27dfa151f380a4afeebebe6dfa381e Mon Sep 17 00:00:00 2001 From: Lance Galletti Date: Wed, 2 Jul 2025 15:08:10 -0400 Subject: [PATCH] tool index refreshing when token is provided to create_turn --- docs/_static/llama-stack-spec.html | 39 ++++++++ docs/_static/llama-stack-spec.yaml | 28 ++++++ llama_stack/apis/tools/tools.py | 11 +++ .../distribution/routing_tables/toolgroups.py | 91 +++++++++++++++---- llama_stack/models/llama/datatypes.py | 1 + .../agents/meta_reference/agent_instance.py | 34 +++++++ .../model_context_protocol.py | 34 ++++++- 7 files changed, 217 insertions(+), 21 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index db5c57821..7db6b4918 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4566,6 +4566,42 @@ } } }, + "/v1/toolgroups/{toolgroup_id}/refresh": { + "post": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "ToolGroups" + ], + "description": "Refresh tools for a specific toolgroup.", + "parameters": [ + { + "name": "toolgroup_id", + "in": "path", + "description": "The ID of the toolgroup to refresh tools for.", + "required": true, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": { "post": { "responses": { @@ -5588,6 +5624,9 @@ "additionalProperties": { "$ref": "#/components/schemas/ToolParamDefinition" } + }, + "toolgroup_name": { + "type": "string" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 29ba9dede..25475dddb 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3219,6 +3219,32 @@ paths: schema: $ref: '#/components/schemas/QueryTracesRequest' required: true + /v1/toolgroups/{toolgroup_id}/refresh: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolGroups + description: Refresh tools for a specific toolgroup. + parameters: + - name: toolgroup_id + in: path + description: >- + The ID of the toolgroup to refresh tools for. + required: true + schema: + type: string /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume: post: responses: @@ -3940,6 +3966,8 @@ components: type: object additionalProperties: $ref: '#/components/schemas/ToolParamDefinition' + toolgroup_name: + type: string additionalProperties: false required: - tool_name diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 7d1eeeefb..668511486 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -145,6 +145,17 @@ class ToolGroups(Protocol): """ ... + @webmethod(route="/toolgroups/{toolgroup_id:path}/refresh", method="POST") + async def refresh_tools( + self, + toolgroup_id: str, + ) -> None: + """Refresh tools for a specific toolgroup. + + :param toolgroup_id: The ID of the toolgroup to refresh tools for. + """ + ... + @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE") async def unregister_toolgroup( self, diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index b86f057bd..aad406a34 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -43,6 +43,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return super().get_provider_impl(routing_key, provider_id) async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: + logger.debug(f"Listing tools for toolgroup_id: {toolgroup_id}") + if toolgroup_id: if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id): toolgroup_id = group_id @@ -53,33 +55,59 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): all_tools = [] for toolgroup in toolgroups: if toolgroup.identifier not in self.toolgroups_to_tools: + logger.debug(f"Toolgroup {toolgroup.identifier} not in cache, indexing...") await self._index_tools(toolgroup) - all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier]) + cached_tools = self.toolgroups_to_tools.get(toolgroup.identifier, []) + logger.debug(f"Found {len(cached_tools)} cached tools for toolgroup {toolgroup.identifier}") + all_tools.extend(cached_tools) + + logger.debug(f"Returning {len(all_tools)} total tools") return ListToolsResponse(data=all_tools) async def _index_tools(self, toolgroup: ToolGroup): - provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) - tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint) + try: + provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) + logger.debug(f"Indexing tools for toolgroup {toolgroup.identifier} with provider {toolgroup.provider_id}") - # TODO: kill this Tool vs ToolDef distinction - tooldefs = tooldefs_response.data - tools = [] - for t in tooldefs: - tools.append( - Tool( - identifier=t.name, - toolgroup_id=toolgroup.identifier, - description=t.description or "", - parameters=t.parameters or [], - metadata=t.metadata, - provider_id=toolgroup.provider_id, + if toolgroup.mcp_endpoint: + logger.debug(f"Toolgroup {toolgroup.identifier} has MCP endpoint: {toolgroup.mcp_endpoint.uri}") + + tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint) + + # TODO: kill this Tool vs ToolDef distinction + tooldefs = tooldefs_response.data + tools = [] + for t in tooldefs: + tools.append( + Tool( + identifier=t.name, + toolgroup_id=toolgroup.identifier, + description=t.description or "", + parameters=t.parameters or [], + metadata=t.metadata, + provider_id=toolgroup.provider_id, + ) ) - ) - self.toolgroups_to_tools[toolgroup.identifier] = tools - for tool in tools: - self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier + self.toolgroups_to_tools[toolgroup.identifier] = tools + for tool in tools: + self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier + + logger.info(f"Successfully indexed {len(tools)} tools for toolgroup {toolgroup.identifier}") + + except Exception as e: + logger.warning(f"Failed to index tools for toolgroup {toolgroup.identifier}: {e}") + # Don't let tool indexing failures crash the system + # Initialize empty tools list so the toolgroup still exists + self.toolgroups_to_tools[toolgroup.identifier] = [] + if toolgroup.mcp_endpoint: + logger.info( + f"Toolgroup {toolgroup.identifier} has MCP endpoint - tools may be available after authentication" + ) + else: + logger.error(f"Non-MCP toolgroup {toolgroup.identifier} failed to index tools: {e}") + # Don't raise - we want the system to continue running even if tool indexing fails async def list_tool_groups(self) -> ListToolGroupsResponse: return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) @@ -119,7 +147,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): # the tools should first list the tools and then use them. but there are assumptions # baked in some of the code and tests right now. if not toolgroup.mcp_endpoint: - await self._index_tools(toolgroup) + try: + await self._index_tools(toolgroup) + except Exception as e: + logger.error(f"Failed to index tools during toolgroup registration for {toolgroup_id}: {e}") + # Don't fail the registration - the toolgroup can still be used + # Tools may become available later return toolgroup async def unregister_toolgroup(self, toolgroup_id: str) -> None: @@ -128,5 +161,23 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): raise ValueError(f"Tool group {toolgroup_id} not found") await self.unregister_object(tool_group) + async def refresh_tools(self, toolgroup_id: str) -> None: + """Refresh tools for a specific toolgroup, useful for re-indexing after auth becomes available.""" + try: + toolgroup = await self.get_tool_group(toolgroup_id) + # Clear existing tools for this toolgroup + if toolgroup_id in self.toolgroups_to_tools: + old_tools = self.toolgroups_to_tools[toolgroup_id] + for tool in old_tools: + if tool.identifier in self.tool_to_toolgroup: + del self.tool_to_toolgroup[tool.identifier] + del self.toolgroups_to_tools[toolgroup_id] + + # Re-index tools for this toolgroup + await self._index_tools(toolgroup) + except Exception as e: + # Log error but don't fail - tools may become available later + logger.warning(f"Failed to refresh tools for toolgroup {toolgroup_id}: {e}") + async def shutdown(self) -> None: pass diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index 7f1ebed55..db45ec49d 100644 --- a/llama_stack/models/llama/datatypes.py +++ b/llama_stack/models/llama/datatypes.py @@ -99,6 +99,7 @@ class ToolDefinition(BaseModel): tool_name: BuiltinTool | str description: str | None = None parameters: dict[str, ToolParamDefinition] | None = None + toolgroup_name: str | None = None @field_validator("tool_name", mode="before") @classmethod diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 4d2b9f8bf..6d156847c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -185,6 +185,9 @@ class ChatAgent(ShieldRunnerMixin): if self.agent_config.name: span.set_attribute("agent_name", self.agent_config.name) + # Refresh tools for MCP toolgroups in case auth is now available + await self._refresh_mcp_tools(request.toolgroups) + await self._initialize_tools(request.toolgroups) async for chunk in self._run_turn(request, turn_id): yield chunk @@ -762,6 +765,35 @@ class ChatAgent(ShieldRunnerMixin): yield client_message return + async def _refresh_mcp_tools( + self, + toolgroups_for_turn: list[AgentToolGroup] | None = None, + ) -> None: + """Refresh MCP tools in case authentication is now available.""" + # Determine which tools to refresh + tool_groups_to_refresh = toolgroups_for_turn or self.agent_config.toolgroups or [] + + logger.debug(f"Refreshing MCP tools for {len(tool_groups_to_refresh)} toolgroups") + + for toolgroup in tool_groups_to_refresh: + name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup + toolgroup_name, _ = self._parse_toolgroup_name(name) + + # Only refresh MCP toolgroups (those with endpoints) + try: + tool_group = await self.tool_groups_api.get_tool_group(toolgroup_name) + if tool_group.mcp_endpoint: + logger.debug( + f"Refreshing tools for MCP toolgroup {toolgroup_name} with endpoint {tool_group.mcp_endpoint.uri}" + ) + # Refresh tools for this MCP toolgroup + await self.tool_groups_api.refresh_tools(toolgroup_name) + else: + logger.debug(f"Toolgroup {toolgroup_name} is not an MCP toolgroup, skipping refresh") + except Exception as e: + # Log but don't fail - tools may become available later + logger.warning(f"Failed to refresh MCP tools for toolgroup {toolgroup_name}: {e}") + async def _initialize_tools( self, toolgroups_for_turn: list[AgentToolGroup] | None = None, @@ -800,6 +832,7 @@ class ChatAgent(ShieldRunnerMixin): ) for param in tool_def.parameters }, + toolgroup_name="client_tools", ) for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) @@ -843,6 +876,7 @@ class ChatAgent(ShieldRunnerMixin): ) for param in tool_def.parameters }, + toolgroup_name=toolgroup_name, ) tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {}) diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index a9b252dfe..195454864 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -15,6 +15,7 @@ from llama_stack.apis.tools import ( ToolInvocationResult, ToolRuntime, ) +from llama_stack.distribution.datatypes import AuthenticationRequiredError from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate @@ -44,14 +45,45 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime # this endpoint should be retrieved by getting the tool group right? if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") + + logger.debug(f"Listing runtime tools for toolgroup {tool_group_id} at endpoint {mcp_endpoint.uri}") headers = await self.get_headers_from_request(mcp_endpoint.uri) - return await list_mcp_tools(mcp_endpoint.uri, headers) + + if headers: + logger.debug(f"Found {len(headers)} headers for MCP endpoint {mcp_endpoint.uri}") + # Log header keys but not values for security + header_keys = list(headers.keys()) + logger.debug(f"Header keys: {header_keys}") + else: + logger.debug(f"No headers found for MCP endpoint {mcp_endpoint.uri}") + + try: + result = await list_mcp_tools(mcp_endpoint.uri, headers) + logger.info(f"Successfully listed {len(result.data)} tools for toolgroup {tool_group_id}") + return result + except AuthenticationRequiredError as e: + logger.warning(f"Authentication required for MCP endpoint {mcp_endpoint.uri}: {e}") + logger.info( + f"Returning empty tool list for toolgroup {tool_group_id} - tools will be refreshed when authentication is available" + ) + # Return empty list on authentication errors during startup + # Tools will be refreshed when a turn is created with proper auth + return ListToolDefsResponse(data=[]) + except Exception as e: + logger.error(f"Failed to list tools for toolgroup {tool_group_id} at endpoint {mcp_endpoint.uri}: {e}") + # Return empty list on other errors too to prevent crashes + return ListToolDefsResponse(data=[]) async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: + if self.tool_store is None: + raise ValueError(f"Tool store is not available for tool {tool_name}") + tool = await self.tool_store.get_tool(tool_name) if tool.metadata is None or tool.metadata.get("endpoint") is None: raise ValueError(f"Tool {tool_name} does not have metadata") endpoint = tool.metadata.get("endpoint") + if endpoint is None: + raise ValueError(f"Tool {tool_name} does not have an endpoint") if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")