mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:42:25 +00:00
Refresh the tool index when a token is provided to create_turn
This commit is contained in:
parent
57745101be
commit
76f593143b
7 changed files with 217 additions and 21 deletions
39
docs/_static/llama-stack-spec.html
vendored
39
docs/_static/llama-stack-spec.html
vendored
|
|
@ -4566,6 +4566,42 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"/v1/toolgroups/{toolgroup_id}/refresh": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"ToolGroups"
|
||||
],
|
||||
"description": "Refresh tools for a specific toolgroup.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "toolgroup_id",
|
||||
"in": "path",
|
||||
"description": "The ID of the toolgroup to refresh tools for.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
|
||||
"post": {
|
||||
"responses": {
|
||||
|
|
@ -5588,6 +5624,9 @@
|
|||
"additionalProperties": {
|
||||
"$ref": "#/components/schemas/ToolParamDefinition"
|
||||
}
|
||||
},
|
||||
"toolgroup_name": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
|
|
|||
28
docs/_static/llama-stack-spec.yaml
vendored
28
docs/_static/llama-stack-spec.yaml
vendored
|
|
@ -3219,6 +3219,32 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/QueryTracesRequest'
|
||||
required: true
|
||||
/v1/toolgroups/{toolgroup_id}/refresh:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolGroups
|
||||
description: Refresh tools for a specific toolgroup.
|
||||
parameters:
|
||||
- name: toolgroup_id
|
||||
in: path
|
||||
description: >-
|
||||
The ID of the toolgroup to refresh tools for.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -3940,6 +3966,8 @@ components:
|
|||
type: object
|
||||
additionalProperties:
|
||||
$ref: '#/components/schemas/ToolParamDefinition'
|
||||
toolgroup_name:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- tool_name
|
||||
|
|
|
|||
|
|
@ -145,6 +145,17 @@ class ToolGroups(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}/refresh", method="POST")
|
||||
async def refresh_tools(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
) -> None:
|
||||
"""Refresh tools for a specific toolgroup.
|
||||
|
||||
:param toolgroup_id: The ID of the toolgroup to refresh tools for.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
|
||||
async def unregister_toolgroup(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -43,6 +43,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
return super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
logger.debug(f"Listing tools for toolgroup_id: {toolgroup_id}")
|
||||
|
||||
if toolgroup_id:
|
||||
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
||||
toolgroup_id = group_id
|
||||
|
|
@ -53,33 +55,59 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
all_tools = []
|
||||
for toolgroup in toolgroups:
|
||||
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||
logger.debug(f"Toolgroup {toolgroup.identifier} not in cache, indexing...")
|
||||
await self._index_tools(toolgroup)
|
||||
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
||||
|
||||
cached_tools = self.toolgroups_to_tools.get(toolgroup.identifier, [])
|
||||
logger.debug(f"Found {len(cached_tools)} cached tools for toolgroup {toolgroup.identifier}")
|
||||
all_tools.extend(cached_tools)
|
||||
|
||||
logger.debug(f"Returning {len(all_tools)} total tools")
|
||||
return ListToolsResponse(data=all_tools)
|
||||
|
||||
async def _index_tools(self, toolgroup: ToolGroup):
|
||||
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
|
||||
try:
|
||||
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
logger.debug(f"Indexing tools for toolgroup {toolgroup.identifier} with provider {toolgroup.provider_id}")
|
||||
|
||||
# TODO: kill this Tool vs ToolDef distinction
|
||||
tooldefs = tooldefs_response.data
|
||||
tools = []
|
||||
for t in tooldefs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=t.name,
|
||||
toolgroup_id=toolgroup.identifier,
|
||||
description=t.description or "",
|
||||
parameters=t.parameters or [],
|
||||
metadata=t.metadata,
|
||||
provider_id=toolgroup.provider_id,
|
||||
if toolgroup.mcp_endpoint:
|
||||
logger.debug(f"Toolgroup {toolgroup.identifier} has MCP endpoint: {toolgroup.mcp_endpoint.uri}")
|
||||
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
|
||||
|
||||
# TODO: kill this Tool vs ToolDef distinction
|
||||
tooldefs = tooldefs_response.data
|
||||
tools = []
|
||||
for t in tooldefs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=t.name,
|
||||
toolgroup_id=toolgroup.identifier,
|
||||
description=t.description or "",
|
||||
parameters=t.parameters or [],
|
||||
metadata=t.metadata,
|
||||
provider_id=toolgroup.provider_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = tools
|
||||
for tool in tools:
|
||||
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = tools
|
||||
for tool in tools:
|
||||
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
|
||||
|
||||
logger.info(f"Successfully indexed {len(tools)} tools for toolgroup {toolgroup.identifier}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to index tools for toolgroup {toolgroup.identifier}: {e}")
|
||||
# Don't let tool indexing failures crash the system
|
||||
# Initialize empty tools list so the toolgroup still exists
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = []
|
||||
if toolgroup.mcp_endpoint:
|
||||
logger.info(
|
||||
f"Toolgroup {toolgroup.identifier} has MCP endpoint - tools may be available after authentication"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Non-MCP toolgroup {toolgroup.identifier} failed to index tools: {e}")
|
||||
# Don't raise - we want the system to continue running even if tool indexing fails
|
||||
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||
|
|
@ -119,7 +147,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
# the tools should first list the tools and then use them. but there are assumptions
|
||||
# baked in some of the code and tests right now.
|
||||
if not toolgroup.mcp_endpoint:
|
||||
await self._index_tools(toolgroup)
|
||||
try:
|
||||
await self._index_tools(toolgroup)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to index tools during toolgroup registration for {toolgroup_id}: {e}")
|
||||
# Don't fail the registration - the toolgroup can still be used
|
||||
# Tools may become available later
|
||||
return toolgroup
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
|
|
@ -128,5 +161,23 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||
await self.unregister_object(tool_group)
|
||||
|
||||
async def refresh_tools(self, toolgroup_id: str) -> None:
|
||||
"""Refresh tools for a specific toolgroup, useful for re-indexing after auth becomes available."""
|
||||
try:
|
||||
toolgroup = await self.get_tool_group(toolgroup_id)
|
||||
# Clear existing tools for this toolgroup
|
||||
if toolgroup_id in self.toolgroups_to_tools:
|
||||
old_tools = self.toolgroups_to_tools[toolgroup_id]
|
||||
for tool in old_tools:
|
||||
if tool.identifier in self.tool_to_toolgroup:
|
||||
del self.tool_to_toolgroup[tool.identifier]
|
||||
del self.toolgroups_to_tools[toolgroup_id]
|
||||
|
||||
# Re-index tools for this toolgroup
|
||||
await self._index_tools(toolgroup)
|
||||
except Exception as e:
|
||||
# Log error but don't fail - tools may become available later
|
||||
logger.warning(f"Failed to refresh tools for toolgroup {toolgroup_id}: {e}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -99,6 +99,7 @@ class ToolDefinition(BaseModel):
|
|||
tool_name: BuiltinTool | str
|
||||
description: str | None = None
|
||||
parameters: dict[str, ToolParamDefinition] | None = None
|
||||
toolgroup_name: str | None = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -185,6 +185,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
||||
# Refresh tools for MCP toolgroups in case auth is now available
|
||||
await self._refresh_mcp_tools(request.toolgroups)
|
||||
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async for chunk in self._run_turn(request, turn_id):
|
||||
yield chunk
|
||||
|
|
@ -762,6 +765,35 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield client_message
|
||||
return
|
||||
|
||||
async def _refresh_mcp_tools(
|
||||
self,
|
||||
toolgroups_for_turn: list[AgentToolGroup] | None = None,
|
||||
) -> None:
|
||||
"""Refresh MCP tools in case authentication is now available."""
|
||||
# Determine which tools to refresh
|
||||
tool_groups_to_refresh = toolgroups_for_turn or self.agent_config.toolgroups or []
|
||||
|
||||
logger.debug(f"Refreshing MCP tools for {len(tool_groups_to_refresh)} toolgroups")
|
||||
|
||||
for toolgroup in tool_groups_to_refresh:
|
||||
name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
|
||||
toolgroup_name, _ = self._parse_toolgroup_name(name)
|
||||
|
||||
# Only refresh MCP toolgroups (those with endpoints)
|
||||
try:
|
||||
tool_group = await self.tool_groups_api.get_tool_group(toolgroup_name)
|
||||
if tool_group.mcp_endpoint:
|
||||
logger.debug(
|
||||
f"Refreshing tools for MCP toolgroup {toolgroup_name} with endpoint {tool_group.mcp_endpoint.uri}"
|
||||
)
|
||||
# Refresh tools for this MCP toolgroup
|
||||
await self.tool_groups_api.refresh_tools(toolgroup_name)
|
||||
else:
|
||||
logger.debug(f"Toolgroup {toolgroup_name} is not an MCP toolgroup, skipping refresh")
|
||||
except Exception as e:
|
||||
# Log but don't fail - tools may become available later
|
||||
logger.warning(f"Failed to refresh MCP tools for toolgroup {toolgroup_name}: {e}")
|
||||
|
||||
async def _initialize_tools(
|
||||
self,
|
||||
toolgroups_for_turn: list[AgentToolGroup] | None = None,
|
||||
|
|
@ -800,6 +832,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
toolgroup_name="client_tools",
|
||||
)
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
|
|
@ -843,6 +876,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
toolgroup_name=toolgroup_name,
|
||||
)
|
||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from llama_stack.apis.tools import (
|
|||
ToolInvocationResult,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
|
|
@ -44,14 +45,45 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
|
|||
# this endpoint should be retrieved by getting the tool group right?
|
||||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
|
||||
logger.debug(f"Listing runtime tools for toolgroup {tool_group_id} at endpoint {mcp_endpoint.uri}")
|
||||
headers = await self.get_headers_from_request(mcp_endpoint.uri)
|
||||
return await list_mcp_tools(mcp_endpoint.uri, headers)
|
||||
|
||||
if headers:
|
||||
logger.debug(f"Found {len(headers)} headers for MCP endpoint {mcp_endpoint.uri}")
|
||||
# Log header keys but not values for security
|
||||
header_keys = list(headers.keys())
|
||||
logger.debug(f"Header keys: {header_keys}")
|
||||
else:
|
||||
logger.debug(f"No headers found for MCP endpoint {mcp_endpoint.uri}")
|
||||
|
||||
try:
|
||||
result = await list_mcp_tools(mcp_endpoint.uri, headers)
|
||||
logger.info(f"Successfully listed {len(result.data)} tools for toolgroup {tool_group_id}")
|
||||
return result
|
||||
except AuthenticationRequiredError as e:
|
||||
logger.warning(f"Authentication required for MCP endpoint {mcp_endpoint.uri}: {e}")
|
||||
logger.info(
|
||||
f"Returning empty tool list for toolgroup {tool_group_id} - tools will be refreshed when authentication is available"
|
||||
)
|
||||
# Return empty list on authentication errors during startup
|
||||
# Tools will be refreshed when a turn is created with proper auth
|
||||
return ListToolDefsResponse(data=[])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list tools for toolgroup {tool_group_id} at endpoint {mcp_endpoint.uri}: {e}")
|
||||
# Return empty list on other errors too to prevent crashes
|
||||
return ListToolDefsResponse(data=[])
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
if self.tool_store is None:
|
||||
raise ValueError(f"Tool store is not available for tool {tool_name}")
|
||||
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
||||
raise ValueError(f"Tool {tool_name} does not have metadata")
|
||||
endpoint = tool.metadata.get("endpoint")
|
||||
if endpoint is None:
|
||||
raise ValueError(f"Tool {tool_name} does not have an endpoint")
|
||||
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue