forked from phoenix-oss/llama-stack-mirror
fix(tools): do not index tools, only index toolgroups (#2261)
When registering an MCP endpoint, we cannot list its tools (like we used to) since the MCP endpoint may be behind an auth wall. Registration can also happen much sooner (via run.yaml). Instead, we list tools only when the _user_ actually calls listing. Furthermore, we cache the list in memory in the server. Currently, the cache is not invalidated -- we may want to periodically re-list for MCP servers. Note that clients must call `list_tools` before calling `invoke_tool` -- we rely on this critically, because the tool-to-toolgroup mapping used for routing is only populated during listing. This will enable us to list MCP servers in run.yaml. ## Test Plan Existing tests, updated accordingly.
This commit is contained in:
parent
5a422e236c
commit
ce33d02443
19 changed files with 131 additions and 153 deletions
13
docs/_static/llama-stack-spec.html
vendored
13
docs/_static/llama-stack-spec.html
vendored
|
@ -9555,9 +9555,6 @@
|
||||||
"toolgroup_id": {
|
"toolgroup_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"tool_host": {
|
|
||||||
"$ref": "#/components/schemas/ToolHost"
|
|
||||||
},
|
|
||||||
"description": {
|
"description": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
|
@ -9599,21 +9596,11 @@
|
||||||
"provider_id",
|
"provider_id",
|
||||||
"type",
|
"type",
|
||||||
"toolgroup_id",
|
"toolgroup_id",
|
||||||
"tool_host",
|
|
||||||
"description",
|
"description",
|
||||||
"parameters"
|
"parameters"
|
||||||
],
|
],
|
||||||
"title": "Tool"
|
"title": "Tool"
|
||||||
},
|
},
|
||||||
"ToolHost": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"distribution",
|
|
||||||
"client",
|
|
||||||
"model_context_protocol"
|
|
||||||
],
|
|
||||||
"title": "ToolHost"
|
|
||||||
},
|
|
||||||
"ToolGroup": {
|
"ToolGroup": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
10
docs/_static/llama-stack-spec.yaml
vendored
10
docs/_static/llama-stack-spec.yaml
vendored
|
@ -6713,8 +6713,6 @@ components:
|
||||||
default: tool
|
default: tool
|
||||||
toolgroup_id:
|
toolgroup_id:
|
||||||
type: string
|
type: string
|
||||||
tool_host:
|
|
||||||
$ref: '#/components/schemas/ToolHost'
|
|
||||||
description:
|
description:
|
||||||
type: string
|
type: string
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -6737,17 +6735,9 @@ components:
|
||||||
- provider_id
|
- provider_id
|
||||||
- type
|
- type
|
||||||
- toolgroup_id
|
- toolgroup_id
|
||||||
- tool_host
|
|
||||||
- description
|
- description
|
||||||
- parameters
|
- parameters
|
||||||
title: Tool
|
title: Tool
|
||||||
ToolHost:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- distribution
|
|
||||||
- client
|
|
||||||
- model_context_protocol
|
|
||||||
title: ToolHost
|
|
||||||
ToolGroup:
|
ToolGroup:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
|
||||||
default: Any | None = None
|
default: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolHost(Enum):
|
|
||||||
distribution = "distribution"
|
|
||||||
client = "client"
|
|
||||||
model_context_protocol = "model_context_protocol"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Tool(Resource):
|
class Tool(Resource):
|
||||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||||
toolgroup_id: str
|
toolgroup_id: str
|
||||||
tool_host: ToolHost
|
|
||||||
description: str
|
description: str
|
||||||
parameters: list[ToolParameter]
|
parameters: list[ToolParameter]
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
|
@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
|
||||||
RemoteProviderSpec,
|
RemoteProviderSpec,
|
||||||
ScoringFunctionsProtocolPrivate,
|
ScoringFunctionsProtocolPrivate,
|
||||||
ShieldsProtocolPrivate,
|
ShieldsProtocolPrivate,
|
||||||
ToolsProtocolPrivate,
|
ToolGroupsProtocolPrivate,
|
||||||
VectorDBsProtocolPrivate,
|
VectorDBsProtocolPrivate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
||||||
def additional_protocols_map() -> dict[Api, Any]:
|
def additional_protocols_map() -> dict[Api, Any]:
|
||||||
return {
|
return {
|
||||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||||
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
|
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
|
||||||
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
||||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||||
|
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolsResponse,
|
||||||
RAGDocument,
|
RAGDocument,
|
||||||
RAGQueryConfig,
|
RAGQueryConfig,
|
||||||
RAGQueryResult,
|
RAGQueryResult,
|
||||||
|
@ -19,7 +19,8 @@ from llama_stack.apis.tools import (
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
|
||||||
|
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
class RagToolImpl(RAGToolRuntime):
|
class RagToolImpl(RAGToolRuntime):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: ToolGroupsRoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: ToolGroupsRoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("Initializing ToolRuntimeRouter")
|
logger.debug("Initializing ToolRuntimeRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
@ -86,6 +87,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||||
) -> ListToolDefsResponse:
|
) -> ListToolsResponse:
|
||||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
return await self.routing_table.list_tools(tool_group_id)
|
||||||
|
|
|
@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
return await p.register_benchmark(obj)
|
return await p.register_benchmark(obj)
|
||||||
elif api == Api.tool_runtime:
|
elif api == Api.tool_runtime:
|
||||||
return await p.register_tool(obj)
|
return await p.register_toolgroup(obj)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
return await p.unregister_dataset(obj.identifier)
|
return await p.unregister_dataset(obj.identifier)
|
||||||
elif api == Api.tool_runtime:
|
elif api == Api.tool_runtime:
|
||||||
return await p.unregister_tool(obj.identifier)
|
return await p.unregister_toolgroup(obj.identifier)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unregister not supported for {api}")
|
raise ValueError(f"Unregister not supported for {api}")
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
elif isinstance(self, BenchmarksRoutingTable):
|
elif isinstance(self, BenchmarksRoutingTable):
|
||||||
return ("Eval", "benchmark")
|
return ("Eval", "benchmark")
|
||||||
elif isinstance(self, ToolGroupsRoutingTable):
|
elif isinstance(self, ToolGroupsRoutingTable):
|
||||||
return ("Tools", "tool")
|
return ("ToolGroups", "tool_group")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown routing table type")
|
raise ValueError("Unknown routing table type")
|
||||||
|
|
||||||
|
|
|
@ -7,11 +7,8 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
|
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import ToolGroupWithACL
|
||||||
ToolGroupWithACL,
|
|
||||||
ToolWithACL,
|
|
||||||
)
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
@ -20,11 +17,51 @@ logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
|
toolgroups_to_tools: dict[str, list[Tool]] = {}
|
||||||
|
tool_to_toolgroup: dict[str, str] = {}
|
||||||
|
|
||||||
|
# overridden
|
||||||
|
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||||
|
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
|
||||||
|
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
|
||||||
|
if routing_key in self.tool_to_toolgroup:
|
||||||
|
routing_key = self.tool_to_toolgroup[routing_key]
|
||||||
|
return super().get_provider_impl(routing_key, provider_id)
|
||||||
|
|
||||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||||
tools = await self.get_all_with_type("tool")
|
|
||||||
if toolgroup_id:
|
if toolgroup_id:
|
||||||
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
|
toolgroups = [await self.get_tool_group(toolgroup_id)]
|
||||||
return ListToolsResponse(data=tools)
|
else:
|
||||||
|
toolgroups = await self.get_all_with_type("tool_group")
|
||||||
|
|
||||||
|
all_tools = []
|
||||||
|
for toolgroup in toolgroups:
|
||||||
|
group_id = toolgroup.identifier
|
||||||
|
if group_id not in self.toolgroups_to_tools:
|
||||||
|
provider_impl = super().get_provider_impl(group_id, toolgroup.provider_id)
|
||||||
|
tooldefs_response = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
|
||||||
|
|
||||||
|
# TODO: kill this Tool vs ToolDef distinction
|
||||||
|
tooldefs = tooldefs_response.data
|
||||||
|
tools = []
|
||||||
|
for t in tooldefs:
|
||||||
|
tools.append(
|
||||||
|
Tool(
|
||||||
|
identifier=t.name,
|
||||||
|
toolgroup_id=group_id,
|
||||||
|
description=t.description or "",
|
||||||
|
parameters=t.parameters or [],
|
||||||
|
metadata=t.metadata,
|
||||||
|
provider_id=toolgroup.provider_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.toolgroups_to_tools[group_id] = tools
|
||||||
|
for tool in tools:
|
||||||
|
self.tool_to_toolgroup[tool.identifier] = group_id
|
||||||
|
all_tools.extend(self.toolgroups_to_tools[group_id])
|
||||||
|
|
||||||
|
return ListToolsResponse(data=all_tools)
|
||||||
|
|
||||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||||
|
@ -36,7 +73,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
return tool_group
|
return tool_group
|
||||||
|
|
||||||
async def get_tool(self, tool_name: str) -> Tool:
|
async def get_tool(self, tool_name: str) -> Tool:
|
||||||
return await self.get_object_by_identifier("tool", tool_name)
|
if tool_name in self.tool_to_toolgroup:
|
||||||
|
toolgroup_id = self.tool_to_toolgroup[tool_name]
|
||||||
|
tools = self.toolgroups_to_tools[toolgroup_id]
|
||||||
|
for tool in tools:
|
||||||
|
if tool.identifier == tool_name:
|
||||||
|
return tool
|
||||||
|
raise ValueError(f"Tool '{tool_name}' not found")
|
||||||
|
|
||||||
async def register_tool_group(
|
async def register_tool_group(
|
||||||
self,
|
self,
|
||||||
|
@ -45,53 +88,20 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
mcp_endpoint: URL | None = None,
|
mcp_endpoint: URL | None = None,
|
||||||
args: dict[str, Any] | None = None,
|
args: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
tools = []
|
toolgroup = ToolGroupWithACL(
|
||||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
identifier=toolgroup_id,
|
||||||
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
provider_id=provider_id,
|
||||||
|
provider_resource_id=toolgroup_id,
|
||||||
for tool_def in tool_defs.data:
|
mcp_endpoint=mcp_endpoint,
|
||||||
tools.append(
|
args=args,
|
||||||
ToolWithACL(
|
|
||||||
identifier=tool_def.name,
|
|
||||||
toolgroup_id=toolgroup_id,
|
|
||||||
description=tool_def.description or "",
|
|
||||||
parameters=tool_def.parameters or [],
|
|
||||||
provider_id=provider_id,
|
|
||||||
provider_resource_id=tool_def.name,
|
|
||||||
metadata=tool_def.metadata,
|
|
||||||
tool_host=tool_host,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for tool in tools:
|
|
||||||
existing_tool = await self.get_tool(tool.identifier)
|
|
||||||
# Compare existing and new object if one exists
|
|
||||||
if existing_tool:
|
|
||||||
existing_dict = existing_tool.model_dump()
|
|
||||||
new_dict = tool.model_dump()
|
|
||||||
|
|
||||||
if existing_dict != new_dict:
|
|
||||||
raise ValueError(
|
|
||||||
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
|
|
||||||
)
|
|
||||||
await self.register_object(tool)
|
|
||||||
|
|
||||||
await self.dist_registry.register(
|
|
||||||
ToolGroupWithACL(
|
|
||||||
identifier=toolgroup_id,
|
|
||||||
provider_id=provider_id,
|
|
||||||
provider_resource_id=toolgroup_id,
|
|
||||||
mcp_endpoint=mcp_endpoint,
|
|
||||||
args=args,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
await self.register_object(toolgroup)
|
||||||
|
return toolgroup
|
||||||
|
|
||||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
tool_group = await self.get_tool_group(toolgroup_id)
|
tool_group = await self.get_tool_group(toolgroup_id)
|
||||||
if tool_group is None:
|
if tool_group is None:
|
||||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||||
tools = await self.list_tools(toolgroup_id)
|
|
||||||
for tool in getattr(tools, "data", []):
|
|
||||||
await self.unregister_object(tool)
|
|
||||||
await self.unregister_object(tool_group)
|
await self.unregister_object(tool_group)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
|
|
@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
|
||||||
|
|
||||||
|
|
||||||
REGISTER_PREFIX = "distributions:registry"
|
REGISTER_PREFIX = "distributions:registry"
|
||||||
KEY_VERSION = "v8"
|
KEY_VERSION = "v9"
|
||||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.apis.tools import Tool
|
from llama_stack.apis.tools import ToolGroup
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
|
||||||
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
|
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class ToolsProtocolPrivate(Protocol):
|
class ToolGroupsProtocolPrivate(Protocol):
|
||||||
async def register_tool(self, tool: Tool) -> None: ...
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None: ...
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
|
||||||
RAGQueryConfig,
|
RAGQueryConfig,
|
||||||
RAGQueryResult,
|
RAGQueryResult,
|
||||||
RAGToolRuntime,
|
RAGToolRuntime,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
content_from_doc,
|
content_from_doc,
|
||||||
|
@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
|
||||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: RagToolRuntimeConfig,
|
config: RagToolRuntimeConfig,
|
||||||
|
@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
async def insert(
|
async def insert(
|
||||||
|
|
|
@ -12,19 +12,19 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import BingSearchToolConfig
|
from .config import BingSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: BingSearchToolConfig):
|
def __init__(self, config: BingSearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
||||||
|
@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -11,30 +11,30 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import BraveSearchToolConfig
|
from .config import BraveSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: BraveSearchToolConfig):
|
def __init__(self, config: BraveSearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -11,12 +11,13 @@ from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
|
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
|
||||||
|
|
||||||
from .config import MCPProviderConfig
|
from .config import MCPProviderConfig
|
||||||
|
@ -24,13 +25,19 @@ from .config import MCPProviderConfig
|
||||||
logger = get_logger(__name__, category="tools")
|
logger = get_logger(__name__, category="tools")
|
||||||
|
|
||||||
|
|
||||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||||
) -> ListToolDefsResponse:
|
) -> ListToolDefsResponse:
|
||||||
|
|
|
@ -12,29 +12,29 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import TavilySearchToolConfig
|
from .config import TavilySearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: TavilySearchToolConfig):
|
def __init__(self, config: TavilySearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -12,19 +12,19 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import WolframAlphaToolConfig
|
from .config import WolframAlphaToolConfig
|
||||||
|
|
||||||
|
|
||||||
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: WolframAlphaToolConfig):
|
def __init__(self, config: WolframAlphaToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.url = "https://api.wolframalpha.com/v2/query"
|
self.url = "https://api.wolframalpha.com/v2/query"
|
||||||
|
@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query):
|
||||||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||||
|
|
||||||
|
tools = llama_stack_client.tool_runtime.list_tools()
|
||||||
|
assert any(tool.identifier == "web_search" for tool in tools)
|
||||||
|
|
||||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||||
tool_name="web_search", kwargs={"query": sample_search_query}
|
tool_name="web_search", kwargs={"query": sample_search_query}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the response
|
# Verify the response
|
||||||
assert response.content is not None
|
assert response.content is not None
|
||||||
assert len(response.content) > 0
|
assert len(response.content) > 0
|
||||||
|
@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
|
||||||
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
||||||
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
||||||
|
|
||||||
|
tools = llama_stack_client.tool_runtime.list_tools()
|
||||||
|
assert any(tool.identifier == "wolfram_alpha" for tool in tools)
|
||||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||||
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
|
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
|
||||||
)
|
)
|
||||||
|
|
||||||
print(response.content)
|
|
||||||
assert response.content is not None
|
assert response.content is not None
|
||||||
assert len(response.content) > 0
|
assert len(response.content) > 0
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
|
@ -31,13 +31,12 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
|
||||||
test_toolgroup_id = MCP_TOOLGROUP_ID
|
test_toolgroup_id = MCP_TOOLGROUP_ID
|
||||||
uri = mcp_server["server_url"]
|
uri = mcp_server["server_url"]
|
||||||
|
|
||||||
# registering itself should fail since it requires listing tools
|
# registering should not raise an error anymore even if you don't specify the auth token
|
||||||
with pytest.raises(Exception, match="Unauthorized"):
|
llama_stack_client.toolgroups.register(
|
||||||
llama_stack_client.toolgroups.register(
|
toolgroup_id=test_toolgroup_id,
|
||||||
toolgroup_id=test_toolgroup_id,
|
provider_id="model-context-protocol",
|
||||||
provider_id="model-context-protocol",
|
mcp_endpoint=dict(uri=uri),
|
||||||
mcp_endpoint=dict(uri=uri),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
provider_data = {
|
provider_data = {
|
||||||
"mcp_headers": {
|
"mcp_headers": {
|
||||||
|
@ -50,18 +49,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
|
||||||
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
|
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
with pytest.raises(Exception, match="Unauthorized"):
|
||||||
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
|
llama_stack_client.tools.list()
|
||||||
except Exception as e:
|
|
||||||
# An error is OK since the toolgroup may not exist
|
|
||||||
print(f"Error unregistering toolgroup: {e}")
|
|
||||||
|
|
||||||
llama_stack_client.toolgroups.register(
|
|
||||||
toolgroup_id=test_toolgroup_id,
|
|
||||||
provider_id="model-context-protocol",
|
|
||||||
mcp_endpoint=dict(uri=uri),
|
|
||||||
extra_headers=auth_headers,
|
|
||||||
)
|
|
||||||
response = llama_stack_client.tools.list(
|
response = llama_stack_client.tools.list(
|
||||||
toolgroup_id=test_toolgroup_id,
|
toolgroup_id=test_toolgroup_id,
|
||||||
extra_headers=auth_headers,
|
extra_headers=auth_headers,
|
||||||
|
|
|
@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client):
|
||||||
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||||
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
||||||
|
|
||||||
# Verify tools are also unregistered
|
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||||
unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
|
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
|
||||||
assert isinstance(unregister_tools_list_response, list)
|
|
||||||
assert not unregister_tools_list_response
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models.models import Model, ModelType
|
from llama_stack.apis.models.models import Model, ModelType
|
||||||
from llama_stack.apis.shields.shields import Shield
|
from llama_stack.apis.shields.shields import Shield
|
||||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
|
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
||||||
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
||||||
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
||||||
|
@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(Api.tool_runtime)
|
super().__init__(Api.tool_runtime)
|
||||||
|
|
||||||
async def register_tool(self, tool):
|
async def register_toolgroup(self, toolgroup: ToolGroup):
|
||||||
return tool
|
return toolgroup
|
||||||
|
|
||||||
async def unregister_tool(self, tool_name: str):
|
async def unregister_toolgroup(self, toolgroup_id: str):
|
||||||
return tool_name
|
return toolgroup_id
|
||||||
|
|
||||||
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
|
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
|
||||||
return ListToolDefsResponse(
|
return ListToolDefsResponse(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue