final changes

This commit is contained in:
Dinesh Yeduguru 2024-12-19 20:59:47 -08:00
parent a297d27d48
commit de065a60f2
11 changed files with 142 additions and 108 deletions

View file

@ -6,21 +6,19 @@
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from pydantic import parse_obj_as
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.tools import * # noqa: F403
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.store import DistributionRegistry
def get_impl_api(p: Any) -> Api:
@ -131,7 +129,7 @@ class CommonRoutingTableImpl(RoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, EvalTasksRoutingTable):
return ("Eval", "eval_task")
elif isinstance(self, ToolsRoutingTable):
elif isinstance(self, ToolGroupsRoutingTable):
return ("Tools", "tool")
else:
raise ValueError("Unknown routing table type")
@ -471,65 +469,86 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
await self.register_object(eval_task)
class ToolsRoutingTable(CommonRoutingTableImpl, Tools):
async def list_tools(self) -> List[Tool]:
return await self.get_all_with_type("tool")
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
tools = await self.get_all_with_type("tool")
if tool_group_id:
tools = [tool for tool in tools if tool.tool_group == tool_group_id]
return tools
async def get_tool(self, tool_id: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_id)
async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", tool_group_id)
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
async def register_tool_group(
self,
tool_group: ToolGroup,
tool_group_id: str,
tool_group: ToolGroupDef,
provider_id: Optional[str] = None,
) -> None:
tools = []
if isinstance(tool_group, MCPToolGroup):
# TODO: Actually find the right MCP provider
if provider_id is None:
raise ValueError("MCP provider_id not specified")
tools = await self.impls_by_provider_id[provider_id].discover_tools(
tool_defs = []
if provider_id is None:
if len(self.impls_by_provider_id.keys()) > 1:
raise ValueError(
f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
)
provider_id = list(self.impls_by_provider_id.keys())[0]
if isinstance(tool_group, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group
)
for tool in tools:
tool.provider_id = provider_id
elif isinstance(tool_group, UserDefinedToolGroup):
for tool in tool_group.tools:
tools.append(
Tool(
identifier=tool.name,
tool_group=tool_group.name,
name=tool.name,
description=tool.description,
parameters=tool.parameters,
provider_id=provider_id,
tool_prompt_format=tool.tool_prompt_format,
provider_resource_id=tool.name,
metadata=tool.metadata,
)
)
elif isinstance(tool_group, UserDefinedToolGroupDef):
tool_defs = tool_group.tools
else:
raise ValueError(f"Unknown tool group: {tool_group}")
for tool_def in tool_defs:
tools.append(
Tool(
identifier=tool_def.name,
tool_group=tool_group_id,
description=tool_def.description,
parameters=tool_def.parameters,
provider_id=provider_id,
tool_prompt_format=tool_def.tool_prompt_format,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
)
)
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
if existing_tool:
# Compare all fields except provider_id since that might be None in new obj
if tool.provider_id is None:
tool.provider_id = existing_tool.provider_id
existing_dict = existing_tool.model_dump()
new_dict = tool.model_dump()
if existing_dict != new_dict:
raise ValueError(
f"Object {tool.name} already exists in registry. Please use a different identifier."
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
)
await self.register_object(tool)
async def unregister_tool(self, tool_id: str) -> None:
tool = await self.get_tool(tool_id)
if tool is None:
raise ValueError(f"Tool {tool_id} not found")
await self.unregister_object(tool)
await self.dist_registry.register(
ToolGroup(
identifier=tool_group_id,
provider_id=provider_id,
provider_resource_id=tool_group_id,
)
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
tool_group = await self.get_tool_group(tool_group_id)
if tool_group is None:
raise ValueError(f"Tool group {tool_group_id} not found")
tools = await self.list_tools(tool_group_id)
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)