several fixes

This commit is contained in:
Ashwin Bharambe 2025-05-25 10:35:48 -07:00
parent bf8a73e09a
commit cddc1f3524
15 changed files with 95 additions and 83 deletions

View file

@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
RemoteProviderSpec,
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolsProtocolPrivate,
ToolGroupsProtocolPrivate,
VectorDBsProtocolPrivate,
)
@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),

View file

@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolDefsResponse,
ListToolsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
@ -87,6 +87,6 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
) -> ListToolsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
return await self.routing_table.list_tools(tool_group_id)

View file

@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
elif api == Api.eval:
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
return await p.register_tool(obj)
return await p.register_toolgroup(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_tool(obj.identifier)
return await p.unregister_toolgroup(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable):
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
return ("Tools", "tool")
return ("ToolGroups", "tool_group")
else:
raise ValueError("Unknown routing table type")

View file

@ -24,9 +24,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
tool_name = routing_key
if tool_name in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[tool_name]
if routing_key in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[routing_key]
return super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
@ -39,11 +38,26 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for toolgroup in toolgroups:
group_id = toolgroup.identifier
if group_id not in self.toolgroups_to_tools:
provider_impl = self.get_provider_impl(toolgroup.provider_id)
tools = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
provider_impl = super().get_provider_impl(group_id, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
self.toolgroups_to_tools[group_id] = tools.data
for tool in tools.data:
# TODO: kill this Tool vs ToolDef distinction
tooldefs = tooldefs_response.data
tools = []
for t in tooldefs:
tools.append(
Tool(
identifier=t.name,
toolgroup_id=group_id,
description=t.description or "",
parameters=t.parameters or [],
metadata=t.metadata,
provider_id=toolgroup.provider_id,
)
)
self.toolgroups_to_tools[group_id] = tools
for tool in tools:
self.tool_to_toolgroup[tool.identifier] = group_id
all_tools.extend(self.toolgroups_to_tools[group_id])
@ -74,15 +88,15 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
await self.dist_registry.register(
ToolGroupWithACL(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
toolgroup = ToolGroupWithACL(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
await self.register_object(toolgroup)
return toolgroup
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)