mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:12:29 +00:00
rename UserDefinedToolDef to ToolDef
This commit is contained in:
parent
db0b2a60c1
commit
e3775eb6f6
8 changed files with 180 additions and 322 deletions
|
|
@ -27,15 +27,12 @@ from llama_stack.apis.scoring_functions import (
|
|||
)
|
||||
from llama_stack.apis.shields import Shield, Shields
|
||||
from llama_stack.apis.tools import (
|
||||
BuiltInToolDef,
|
||||
MCPToolGroupDef,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ToolGroupDef,
|
||||
ToolGroups,
|
||||
ToolHost,
|
||||
ToolPromptFormat,
|
||||
UserDefinedToolDef,
|
||||
UserDefinedToolGroupDef,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import (
|
||||
|
|
@ -514,7 +511,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
async def register_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
tool_group: ToolGroupDef,
|
||||
tool_group_def: ToolGroupDef,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
|
|
@ -528,47 +525,31 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
||||
# parse tool group to the type if dict
|
||||
tool_group = TypeAdapter(ToolGroupDef).validate_python(tool_group)
|
||||
if isinstance(tool_group, MCPToolGroupDef):
|
||||
tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def)
|
||||
if isinstance(tool_group_def, MCPToolGroupDef):
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||
tool_group
|
||||
tool_group_def
|
||||
)
|
||||
tool_host = ToolHost.model_context_protocol
|
||||
elif isinstance(tool_group, UserDefinedToolGroupDef):
|
||||
tool_defs = tool_group.tools
|
||||
elif isinstance(tool_group_def, UserDefinedToolGroupDef):
|
||||
tool_defs = tool_group_def.tools
|
||||
else:
|
||||
raise ValueError(f"Unknown tool group: {tool_group}")
|
||||
raise ValueError(f"Unknown tool group: {tool_group_def}")
|
||||
|
||||
for tool_def in tool_defs:
|
||||
if isinstance(tool_def, UserDefinedToolDef):
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.name,
|
||||
tool_group=tool_group_id,
|
||||
description=tool_def.description,
|
||||
parameters=tool_def.parameters,
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=tool_def.tool_prompt_format,
|
||||
provider_resource_id=tool_def.name,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
)
|
||||
)
|
||||
elif isinstance(tool_def, BuiltInToolDef):
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.built_in_type.value,
|
||||
tool_group=tool_group_id,
|
||||
built_in_type=tool_def.built_in_type,
|
||||
description="",
|
||||
parameters=[],
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
provider_resource_id=tool_def.built_in_type.value,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
)
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.name,
|
||||
tool_group=tool_group_id,
|
||||
description=tool_def.description or "",
|
||||
parameters=tool_def.parameters or [],
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=tool_def.tool_prompt_format,
|
||||
provider_resource_id=tool_def.name,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
)
|
||||
)
|
||||
for tool in tools:
|
||||
existing_tool = await self.get_tool(tool.identifier)
|
||||
# Compare existing and new object if one exists
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue