agents to use tools api (#673)

# What does this PR do?

PR #639 introduced the notion of Tools API and ability to invoke tools
through API just as any resource. This PR changes the Agents to start
using the Tools API to invoke tools. Major changes include:
1) Ability to specify tool groups with AgentConfig
2) Agent gets the corresponding tool definitions for the specified tools
and pass along to the model
3) Attachements are now named as Documents and their behavior is mostly
unchanged from user perspective
4) You can specify args that can be injected to a tool call through
Agent config. This is especially useful in case of memory tool, where
you want the tool to operate on a specific memory bank.
5) You can also register tool groups with args, which lets the agent
inject these as well into the tool call.
6) All tests have been migrated to use new tools API and fixtures
including client SDK tests
7) Telemetry just works with tools API because of our trace protocol
decorator


## Test Plan
```
pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

pytest -s -v -k together  llama_stack/providers/tests/tools/test_tools.py \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py
```
run.yaml:
https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994

Notebook:
https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
This commit is contained in:
Dinesh Yeduguru 2025-01-08 19:01:00 -08:00 committed by GitHub
parent 596afc6497
commit a5c57cd381
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
116 changed files with 4959 additions and 2778 deletions

View file

@ -18,15 +18,11 @@ from typing import (
Union,
)
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type, webmethod
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
from llama_stack.apis.inference import (
CompletionMessage,
SamplingParams,
@ -40,166 +36,18 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type
class Attachment(BaseModel):
content: InterleavedContent | URL
mime_type: str
class AgentTool(Enum):
brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha"
photogen = "photogen"
code_interpreter = "code_interpreter"
function_call = "function_call"
memory = "memory"
class ToolDefinitionCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
class SearchEngineType(Enum):
bing = "bing"
brave = "brave"
tavily = "tavily"
@json_schema_type
class SearchToolDefinition(ToolDefinitionCommon):
# NOTE: brave_search is just a placeholder since model always uses
# brave_search as tool call name
type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
api_key: str
engine: SearchEngineType = SearchEngineType.brave
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
api_key: str
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
enable_inline_code_execution: bool = True
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
function_name: str
description: str
parameters: Dict[str, ToolParamDefinition]
remote_execution: Optional[RestAPIExecutionConfig] = None
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["vector"] = "vector"
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyvalue"] = "keyvalue"
keys: List[str] # what keys to focus on
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyword"] = "keyword"
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["graph"] = "graph"
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
AgentVectorMemoryBankConfig,
AgentKeyValueMemoryBankConfig,
AgentKeywordMemoryBankConfig,
AgentGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class MemoryQueryGenerator(Enum):
default = "default"
llm = "llm"
custom = "custom"
class DefaultMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.default.value] = (
MemoryQueryGenerator.default.value
)
sep: str = " "
class LLMMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
model: str
template: str
class CustomMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
MemoryQueryGeneratorConfig = Annotated[
Union[
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
CustomMemoryQueryGeneratorConfig,
],
Field(discriminator="type"),
]
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.memory.value] = AgentTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: MemoryQueryGeneratorConfig = Field(
default=DefaultMemoryQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096
max_chunks: int = 10
AgentToolDefinition = Annotated[
Union[
SearchToolDefinition,
WolframAlphaToolDefinition,
PhotogenToolDefinition,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
MemoryToolDefinition,
],
Field(discriminator="type"),
]
class Document(BaseModel):
content: InterleavedContent | URL
mime_type: str
class StepCommon(BaseModel):
@ -289,13 +137,27 @@ class Session(BaseModel):
memory_bank: Optional[MemoryBank] = None
class AgentToolGroupWithArgs(BaseModel):
name: str
args: Dict[str, Any]
AgentToolGroup = register_schema(
Union[
str,
AgentToolGroupWithArgs,
],
name="AgentTool",
)
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@ -340,6 +202,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
AgentTurnResponseEventType.step_complete.value
)
step_type: StepType
step_id: str
step_details: Step
@ -413,7 +276,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
ToolResponseMessage,
]
]
attachments: Optional[List[Attachment]] = None
documents: Optional[List[Document]] = None
toolgroups: Optional[List[AgentToolGroup]] = None
stream: Optional[bool] = False
@ -450,8 +315,9 @@ class Agents(Protocol):
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")

View file

@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@ -21,15 +22,24 @@ class ToolParameter(BaseModel):
name: str
parameter_type: str
description: str
required: bool = Field(default=True)
default: Optional[Any] = None
@json_schema_type
class ToolHost(Enum):
distribution = "distribution"
client = "client"
model_context_protocol = "model_context_protocol"
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
tool_group: str
toolgroup_id: str
tool_host: ToolHost
description: str
parameters: List[ToolParameter]
provider_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@ -39,41 +49,27 @@ class Tool(Resource):
@json_schema_type
class ToolDef(BaseModel):
name: str
description: str
parameters: List[ToolParameter]
metadata: Dict[str, Any]
description: Optional[str] = None
parameters: Optional[List[ToolParameter]] = None
metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
@json_schema_type
class MCPToolGroupDef(BaseModel):
"""
A tool group that is defined by in a model context protocol server.
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
"""
type: Literal["model_context_protocol"] = "model_context_protocol"
endpoint: URL
class ToolGroupInput(BaseModel):
toolgroup_id: str
provider_id: str
args: Optional[Dict[str, Any]] = None
mcp_endpoint: Optional[URL] = None
@json_schema_type
class UserDefinedToolGroupDef(BaseModel):
type: Literal["user_defined"] = "user_defined"
tools: List[ToolDef]
ToolGroupDef = register_schema(
Annotated[
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
],
name="ToolGroup",
)
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None
args: Optional[Dict[str, Any]] = None
@json_schema_type
@ -85,6 +81,7 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
@runtime_checkable
@ -93,9 +90,10 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group(
self,
tool_group_id: str,
tool_group: ToolGroupDef,
provider_id: Optional[str] = None,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
"""Register a tool group"""
...
@ -103,7 +101,7 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/get", method="GET")
async def get_tool_group(
self,
tool_group_id: str,
toolgroup_id: str,
) -> ToolGroup: ...
@webmethod(route="/toolgroups/list", method="GET")
@ -130,8 +128,11 @@ class ToolGroups(Protocol):
class ToolRuntime(Protocol):
tool_store: ToolStore
@webmethod(route="/tool-runtime/discover", method="POST")
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(