mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 02:09:40 +00:00
simplify toolgroups registration
This commit is contained in:
parent
ba242c04cc
commit
f9a98c278a
15 changed files with 350 additions and 256 deletions
|
|
@ -137,15 +137,15 @@ class Session(BaseModel):
|
|||
memory_bank: Optional[MemoryBank] = None
|
||||
|
||||
|
||||
class AgentToolWithArgs(BaseModel):
|
||||
class AgentToolGroupWithArgs(BaseModel):
|
||||
name: str
|
||||
args: Dict[str, Any]
|
||||
|
||||
|
||||
AgentTool = register_schema(
|
||||
AgentToolGroup = register_schema(
|
||||
Union[
|
||||
str,
|
||||
AgentToolWithArgs,
|
||||
AgentToolGroupWithArgs,
|
||||
],
|
||||
name="AgentTool",
|
||||
)
|
||||
|
|
@ -156,7 +156,7 @@ class AgentConfigCommon(BaseModel):
|
|||
|
||||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
tools: Optional[List[AgentTool]] = Field(default_factory=list)
|
||||
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
|
||||
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
|
|
@ -278,7 +278,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
]
|
||||
|
||||
documents: Optional[List[Document]] = None
|
||||
tools: Optional[List[AgentTool]] = None
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
|
@ -317,7 +317,7 @@ class Agents(Protocol):
|
|||
],
|
||||
stream: Optional[bool] = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
tools: Optional[List[AgentTool]] = None,
|
||||
tools: Optional[List[AgentToolGroup]] = None,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
|
||||
@webmethod(route="/agents/turn/get")
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ class ToolParameter(BaseModel):
|
|||
name: str
|
||||
parameter_type: str
|
||||
description: str
|
||||
required: bool
|
||||
required: bool = Field(default=True)
|
||||
default: Optional[Any] = None
|
||||
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ class ToolHost(Enum):
|
|||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
||||
tool_group: str
|
||||
toolgroup_id: str
|
||||
tool_host: ToolHost
|
||||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
|
|
@ -58,41 +58,19 @@ class ToolDef(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MCPToolGroupDef(BaseModel):
|
||||
"""
|
||||
A tool group that is defined by in a model context protocol server.
|
||||
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
|
||||
"""
|
||||
|
||||
type: Literal["model_context_protocol"] = "model_context_protocol"
|
||||
endpoint: URL
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UserDefinedToolGroupDef(BaseModel):
|
||||
type: Literal["user_defined"] = "user_defined"
|
||||
tools: List[ToolDef]
|
||||
|
||||
|
||||
ToolGroupDef = register_schema(
|
||||
Annotated[
|
||||
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
|
||||
],
|
||||
name="ToolGroupDef",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolGroupInput(BaseModel):
|
||||
tool_group_id: str
|
||||
tool_group_def: ToolGroupDef
|
||||
provider_id: Optional[str] = None
|
||||
toolgroup_id: str
|
||||
provider_id: str
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
mcp_endpoint: Optional[URL] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolGroup(Resource):
|
||||
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
|
||||
mcp_endpoint: Optional[URL] = None
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -104,6 +82,7 @@ class ToolInvocationResult(BaseModel):
|
|||
|
||||
class ToolStore(Protocol):
|
||||
def get_tool(self, tool_name: str) -> Tool: ...
|
||||
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
|
@ -112,9 +91,10 @@ class ToolGroups(Protocol):
|
|||
@webmethod(route="/toolgroups/register", method="POST")
|
||||
async def register_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
tool_group_def: ToolGroupDef,
|
||||
provider_id: Optional[str] = None,
|
||||
toolgroup_id: str,
|
||||
provider_id: str,
|
||||
mcp_endpoint: Optional[URL] = None,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Register a tool group"""
|
||||
...
|
||||
|
|
@ -122,7 +102,7 @@ class ToolGroups(Protocol):
|
|||
@webmethod(route="/toolgroups/get", method="GET")
|
||||
async def get_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
toolgroup_id: str,
|
||||
) -> ToolGroup: ...
|
||||
|
||||
@webmethod(route="/toolgroups/list", method="GET")
|
||||
|
|
@ -149,8 +129,10 @@ class ToolGroups(Protocol):
|
|||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore
|
||||
|
||||
@webmethod(route="/tool-runtime/discover", method="POST")
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
|
||||
@webmethod(route="/tool-runtime/list-tools", method="POST")
|
||||
async def list_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]: ...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
async def invoke_tool(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue