mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:12:29 +00:00
simplify toolgroups registration
This commit is contained in:
parent
ba242c04cc
commit
f9a98c278a
15 changed files with 350 additions and 256 deletions
|
|
@ -26,15 +26,7 @@ from llama_stack.apis.scoring_functions import (
|
|||
ScoringFunctions,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield, Shields
|
||||
from llama_stack.apis.tools import (
|
||||
MCPToolGroupDef,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ToolGroupDef,
|
||||
ToolGroups,
|
||||
ToolHost,
|
||||
UserDefinedToolGroupDef,
|
||||
)
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
|
||||
from llama_stack.distribution.datatypes import (
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
|
|
@ -496,51 +488,38 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
|
||||
tools = await self.get_all_with_type("tool")
|
||||
if tool_group_id:
|
||||
tools = [tool for tool in tools if tool.tool_group == tool_group_id]
|
||||
tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
|
||||
return tools
|
||||
|
||||
async def list_tool_groups(self) -> List[ToolGroup]:
|
||||
return await self.get_all_with_type("tool_group")
|
||||
|
||||
async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
|
||||
return await self.get_object_by_identifier("tool_group", tool_group_id)
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||
return await self.get_object_by_identifier("tool_group", toolgroup_id)
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
return await self.get_object_by_identifier("tool", tool_name)
|
||||
|
||||
async def register_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
tool_group_def: ToolGroupDef,
|
||||
provider_id: Optional[str] = None,
|
||||
toolgroup_id: str,
|
||||
provider_id: str,
|
||||
mcp_endpoint: Optional[URL] = None,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
tool_defs = []
|
||||
tool_host = ToolHost.distribution
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id.keys()) > 1:
|
||||
raise ValueError(
|
||||
f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
|
||||
)
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
||||
# parse tool group to the type if dict
|
||||
tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def)
|
||||
if isinstance(tool_group_def, MCPToolGroupDef):
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||
tool_group_def
|
||||
)
|
||||
tool_host = ToolHost.model_context_protocol
|
||||
elif isinstance(tool_group_def, UserDefinedToolGroupDef):
|
||||
tool_defs = tool_group_def.tools
|
||||
else:
|
||||
raise ValueError(f"Unknown tool group: {tool_group_def}")
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].list_tools(
|
||||
toolgroup_id, mcp_endpoint
|
||||
)
|
||||
tool_host = (
|
||||
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||
)
|
||||
|
||||
for tool_def in tool_defs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.name,
|
||||
tool_group=tool_group_id,
|
||||
toolgroup_id=toolgroup_id,
|
||||
description=tool_def.description or "",
|
||||
parameters=tool_def.parameters or [],
|
||||
provider_id=provider_id,
|
||||
|
|
@ -565,9 +544,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
await self.dist_registry.register(
|
||||
ToolGroup(
|
||||
identifier=tool_group_id,
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=tool_group_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
mcp_endpoint=mcp_endpoint,
|
||||
args=args,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue