simplify toolgroups registration

This commit is contained in:
Dinesh Yeduguru 2025-01-07 15:37:52 -08:00
parent ba242c04cc
commit f9a98c278a
15 changed files with 350 additions and 256 deletions

View file

@ -137,15 +137,15 @@ class Session(BaseModel):
memory_bank: Optional[MemoryBank] = None
class AgentToolWithArgs(BaseModel):
class AgentToolGroupWithArgs(BaseModel):
name: str
args: Dict[str, Any]
AgentTool = register_schema(
AgentToolGroup = register_schema(
Union[
str,
AgentToolWithArgs,
AgentToolGroupWithArgs,
],
name="AgentTool",
)
@ -156,7 +156,7 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentTool]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
@ -278,7 +278,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
]
documents: Optional[List[Document]] = None
tools: Optional[List[AgentTool]] = None
toolgroups: Optional[List[AgentToolGroup]] = None
stream: Optional[bool] = False
@ -317,7 +317,7 @@ class Agents(Protocol):
],
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
tools: Optional[List[AgentTool]] = None,
tools: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")

View file

@ -5,10 +5,10 @@
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional
from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@ -22,7 +22,7 @@ class ToolParameter(BaseModel):
name: str
parameter_type: str
description: str
required: bool
required: bool = Field(default=True)
default: Optional[Any] = None
@ -36,7 +36,7 @@ class ToolHost(Enum):
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
tool_group: str
toolgroup_id: str
tool_host: ToolHost
description: str
parameters: List[ToolParameter]
@ -58,41 +58,19 @@ class ToolDef(BaseModel):
)
@json_schema_type
class MCPToolGroupDef(BaseModel):
"""
A tool group that is defined by in a model context protocol server.
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
"""
type: Literal["model_context_protocol"] = "model_context_protocol"
endpoint: URL
@json_schema_type
class UserDefinedToolGroupDef(BaseModel):
type: Literal["user_defined"] = "user_defined"
tools: List[ToolDef]
ToolGroupDef = register_schema(
Annotated[
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
],
name="ToolGroupDef",
)
@json_schema_type
class ToolGroupInput(BaseModel):
tool_group_id: str
tool_group_def: ToolGroupDef
provider_id: Optional[str] = None
toolgroup_id: str
provider_id: str
args: Optional[Dict[str, Any]] = None
mcp_endpoint: Optional[URL] = None
@json_schema_type
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None
args: Optional[Dict[str, Any]] = None
@json_schema_type
@ -104,6 +82,7 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
@runtime_checkable
@ -112,9 +91,10 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group(
self,
tool_group_id: str,
tool_group_def: ToolGroupDef,
provider_id: Optional[str] = None,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
"""Register a tool group"""
...
@ -122,7 +102,7 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/get", method="GET")
async def get_tool_group(
self,
tool_group_id: str,
toolgroup_id: str,
) -> ToolGroup: ...
@webmethod(route="/toolgroups/list", method="GET")
@ -149,8 +129,10 @@ class ToolGroups(Protocol):
class ToolRuntime(Protocol):
tool_store: ToolStore
@webmethod(route="/tool-runtime/discover", method="POST")
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/list-tools", method="POST")
async def list_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(