mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-27 23:40:24 +00:00
several fixes
This commit is contained in:
parent
bf8a73e09a
commit
cddc1f3524
15 changed files with 95 additions and 83 deletions
|
|
@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
elif api == Api.eval:
|
||||
return await p.register_benchmark(obj)
|
||||
elif api == Api.tool_runtime:
|
||||
return await p.register_tool(obj)
|
||||
return await p.register_toolgroup(obj)
|
||||
else:
|
||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
|
@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
elif api == Api.datasetio:
|
||||
return await p.unregister_dataset(obj.identifier)
|
||||
elif api == Api.tool_runtime:
|
||||
return await p.unregister_tool(obj.identifier)
|
||||
return await p.unregister_toolgroup(obj.identifier)
|
||||
else:
|
||||
raise ValueError(f"Unregister not supported for {api}")
|
||||
|
||||
|
|
@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
elif isinstance(self, BenchmarksRoutingTable):
|
||||
return ("Eval", "benchmark")
|
||||
elif isinstance(self, ToolGroupsRoutingTable):
|
||||
return ("Tools", "tool")
|
||||
return ("ToolGroups", "tool_group")
|
||||
else:
|
||||
raise ValueError("Unknown routing table type")
|
||||
|
||||
|
|
|
|||
|
|
@ -24,9 +24,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
|
||||
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
|
||||
tool_name = routing_key
|
||||
if tool_name in self.tool_to_toolgroup:
|
||||
routing_key = self.tool_to_toolgroup[tool_name]
|
||||
if routing_key in self.tool_to_toolgroup:
|
||||
routing_key = self.tool_to_toolgroup[routing_key]
|
||||
return super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
|
|
@ -39,11 +38,26 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
for toolgroup in toolgroups:
|
||||
group_id = toolgroup.identifier
|
||||
if group_id not in self.toolgroups_to_tools:
|
||||
provider_impl = self.get_provider_impl(toolgroup.provider_id)
|
||||
tools = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
|
||||
provider_impl = super().get_provider_impl(group_id, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
|
||||
|
||||
self.toolgroups_to_tools[group_id] = tools.data
|
||||
for tool in tools.data:
|
||||
# TODO: kill this Tool vs ToolDef distinction
|
||||
tooldefs = tooldefs_response.data
|
||||
tools = []
|
||||
for t in tooldefs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=t.name,
|
||||
toolgroup_id=group_id,
|
||||
description=t.description or "",
|
||||
parameters=t.parameters or [],
|
||||
metadata=t.metadata,
|
||||
provider_id=toolgroup.provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
self.toolgroups_to_tools[group_id] = tools
|
||||
for tool in tools:
|
||||
self.tool_to_toolgroup[tool.identifier] = group_id
|
||||
all_tools.extend(self.toolgroups_to_tools[group_id])
|
||||
|
||||
|
|
@ -74,15 +88,15 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
mcp_endpoint: URL | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
await self.dist_registry.register(
|
||||
ToolGroupWithACL(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
mcp_endpoint=mcp_endpoint,
|
||||
args=args,
|
||||
)
|
||||
toolgroup = ToolGroupWithACL(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
mcp_endpoint=mcp_endpoint,
|
||||
args=args,
|
||||
)
|
||||
await self.register_object(toolgroup)
|
||||
return toolgroup
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
tool_group = await self.get_tool_group(toolgroup_id)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue