mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
agents to use tools api (#673)
# What does this PR do? PR #639 introduced the notion of Tools API and ability to invoke tools through API just as any resource. This PR changes the Agents to start using the Tools API to invoke tools. Major changes include: 1) Ability to specify tool groups with AgentConfig 2) Agent gets the corresponding tool definitions for the specified tools and pass along to the model 3) Attachements are now named as Documents and their behavior is mostly unchanged from user perspective 4) You can specify args that can be injected to a tool call through Agent config. This is especially useful in case of memory tool, where you want the tool to operate on a specific memory bank. 5) You can also register tool groups with args, which lets the agent inject these as well into the tool call. 6) All tests have been migrated to use new tools API and fixtures including client SDK tests 7) Telemetry just works with tools API because of our trace protocol decorator ## Test Plan ``` pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py ``` run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994 Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
This commit is contained in:
parent
596afc6497
commit
a5c57cd381
116 changed files with 4959 additions and 2778 deletions
|
@ -4,10 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
|
@ -21,15 +22,24 @@ class ToolParameter(BaseModel):
|
|||
name: str
|
||||
parameter_type: str
|
||||
description: str
|
||||
required: bool = Field(default=True)
|
||||
default: Optional[Any] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolHost(Enum):
|
||||
distribution = "distribution"
|
||||
client = "client"
|
||||
model_context_protocol = "model_context_protocol"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
||||
tool_group: str
|
||||
toolgroup_id: str
|
||||
tool_host: ToolHost
|
||||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
provider_id: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
|
@ -39,41 +49,27 @@ class Tool(Resource):
|
|||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
metadata: Dict[str, Any]
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[List[ToolParameter]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MCPToolGroupDef(BaseModel):
|
||||
"""
|
||||
A tool group that is defined by in a model context protocol server.
|
||||
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
|
||||
"""
|
||||
|
||||
type: Literal["model_context_protocol"] = "model_context_protocol"
|
||||
endpoint: URL
|
||||
class ToolGroupInput(BaseModel):
|
||||
toolgroup_id: str
|
||||
provider_id: str
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
mcp_endpoint: Optional[URL] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UserDefinedToolGroupDef(BaseModel):
|
||||
type: Literal["user_defined"] = "user_defined"
|
||||
tools: List[ToolDef]
|
||||
|
||||
|
||||
ToolGroupDef = register_schema(
|
||||
Annotated[
|
||||
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
|
||||
],
|
||||
name="ToolGroup",
|
||||
)
|
||||
|
||||
|
||||
class ToolGroup(Resource):
|
||||
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
|
||||
mcp_endpoint: Optional[URL] = None
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -85,6 +81,7 @@ class ToolInvocationResult(BaseModel):
|
|||
|
||||
class ToolStore(Protocol):
|
||||
def get_tool(self, tool_name: str) -> Tool: ...
|
||||
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
@ -93,9 +90,10 @@ class ToolGroups(Protocol):
|
|||
@webmethod(route="/toolgroups/register", method="POST")
|
||||
async def register_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
tool_group: ToolGroupDef,
|
||||
provider_id: Optional[str] = None,
|
||||
toolgroup_id: str,
|
||||
provider_id: str,
|
||||
mcp_endpoint: Optional[URL] = None,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Register a tool group"""
|
||||
...
|
||||
|
@ -103,7 +101,7 @@ class ToolGroups(Protocol):
|
|||
@webmethod(route="/toolgroups/get", method="GET")
|
||||
async def get_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
toolgroup_id: str,
|
||||
) -> ToolGroup: ...
|
||||
|
||||
@webmethod(route="/toolgroups/list", method="GET")
|
||||
|
@ -130,8 +128,11 @@ class ToolGroups(Protocol):
|
|||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore
|
||||
|
||||
@webmethod(route="/tool-runtime/discover", method="POST")
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
|
||||
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]: ...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
async def invoke_tool(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue