tool index refreshing when token is provided to create_turn

This commit is contained in:
Lance Galletti 2025-07-02 15:08:10 -04:00
parent 57745101be
commit 76f593143b
7 changed files with 217 additions and 21 deletions

View file

@ -185,6 +185,9 @@ class ChatAgent(ShieldRunnerMixin):
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
# Refresh tools for MCP toolgroups in case auth is now available
await self._refresh_mcp_tools(request.toolgroups)
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
yield chunk
@ -762,6 +765,35 @@ class ChatAgent(ShieldRunnerMixin):
yield client_message
return
async def _refresh_mcp_tools(
self,
toolgroups_for_turn: list[AgentToolGroup] | None = None,
) -> None:
"""Refresh MCP tools in case authentication is now available."""
# Determine which tools to refresh
tool_groups_to_refresh = toolgroups_for_turn or self.agent_config.toolgroups or []
logger.debug(f"Refreshing MCP tools for {len(tool_groups_to_refresh)} toolgroups")
for toolgroup in tool_groups_to_refresh:
name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
toolgroup_name, _ = self._parse_toolgroup_name(name)
# Only refresh MCP toolgroups (those with endpoints)
try:
tool_group = await self.tool_groups_api.get_tool_group(toolgroup_name)
if tool_group.mcp_endpoint:
logger.debug(
f"Refreshing tools for MCP toolgroup {toolgroup_name} with endpoint {tool_group.mcp_endpoint.uri}"
)
# Refresh tools for this MCP toolgroup
await self.tool_groups_api.refresh_tools(toolgroup_name)
else:
logger.debug(f"Toolgroup {toolgroup_name} is not an MCP toolgroup, skipping refresh")
except Exception as e:
# Log but don't fail - tools may become available later
logger.warning(f"Failed to refresh MCP tools for toolgroup {toolgroup_name}: {e}")
async def _initialize_tools(
self,
toolgroups_for_turn: list[AgentToolGroup] | None = None,
@ -800,6 +832,7 @@ class ChatAgent(ShieldRunnerMixin):
)
for param in tool_def.parameters
},
toolgroup_name="client_tools",
)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
@ -843,6 +876,7 @@ class ChatAgent(ShieldRunnerMixin):
)
for param in tool_def.parameters
},
toolgroup_name=toolgroup_name,
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})

View file

@ -15,6 +15,7 @@ from llama_stack.apis.tools import (
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
@ -44,14 +45,45 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
logger.debug(f"Listing runtime tools for toolgroup {tool_group_id} at endpoint {mcp_endpoint.uri}")
headers = await self.get_headers_from_request(mcp_endpoint.uri)
return await list_mcp_tools(mcp_endpoint.uri, headers)
if headers:
logger.debug(f"Found {len(headers)} headers for MCP endpoint {mcp_endpoint.uri}")
# Log header keys but not values for security
header_keys = list(headers.keys())
logger.debug(f"Header keys: {header_keys}")
else:
logger.debug(f"No headers found for MCP endpoint {mcp_endpoint.uri}")
try:
result = await list_mcp_tools(mcp_endpoint.uri, headers)
logger.info(f"Successfully listed {len(result.data)} tools for toolgroup {tool_group_id}")
return result
except AuthenticationRequiredError as e:
logger.warning(f"Authentication required for MCP endpoint {mcp_endpoint.uri}: {e}")
logger.info(
f"Returning empty tool list for toolgroup {tool_group_id} - tools will be refreshed when authentication is available"
)
# Return empty list on authentication errors during startup
# Tools will be refreshed when a turn is created with proper auth
return ListToolDefsResponse(data=[])
except Exception as e:
logger.error(f"Failed to list tools for toolgroup {tool_group_id} at endpoint {mcp_endpoint.uri}: {e}")
# Return empty list on other errors too to prevent crashes
return ListToolDefsResponse(data=[])
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
if self.tool_store is None:
raise ValueError(f"Tool store is not available for tool {tool_name}")
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
endpoint = tool.metadata.get("endpoint")
if endpoint is None:
raise ValueError(f"Tool {tool_name} does not have an endpoint")
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")