address feedback

This commit is contained in:
Dinesh Yeduguru 2024-12-30 15:47:01 -08:00
parent 8bf3f8ea56
commit ac46bd5eb4
8 changed files with 24 additions and 35 deletions

View file

@ -36,7 +36,7 @@ from llama_stack.apis.inference import (
) )
from llama_stack.apis.memory import MemoryBank from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import CustomToolDef from llama_stack.apis.tools import UserDefinedToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -137,8 +137,8 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list) input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list)
available_tools: Optional[List[str]] = Field(default_factory=list) tool_names: Optional[List[str]] = Field(default_factory=list)
custom_tools: Optional[List[CustomToolDef]] = Field(default_factory=list) client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
preprocessing_tools: Optional[List[str]] = Field(default_factory=list) preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field( tool_prompt_format: Optional[ToolPromptFormat] = Field(

View file

@ -48,8 +48,8 @@ class Tool(Resource):
@json_schema_type @json_schema_type
class CustomToolDef(BaseModel): class UserDefinedToolDef(BaseModel):
type: Literal["custom"] = "custom" type: Literal["user_defined"] = "user_defined"
name: str name: str
description: str description: str
parameters: List[ToolParameter] parameters: List[ToolParameter]
@ -67,7 +67,7 @@ class BuiltInToolDef(BaseModel):
ToolDef = register_schema( ToolDef = register_schema(
Annotated[Union[CustomToolDef, BuiltInToolDef], Field(discriminator="type")], Annotated[Union[UserDefinedToolDef, BuiltInToolDef], Field(discriminator="type")],
name="ToolDef", name="ToolDef",
) )
@ -172,14 +172,3 @@ class ToolRuntime(Protocol):
) -> ToolInvocationResult: ) -> ToolInvocationResult:
"""Run a tool with the given arguments""" """Run a tool with the given arguments"""
... ...
# Three tool types:
# 1. Built-in tools
# 2. Client tools
# 3. Model-context-protocol tools
# Suport registration of agents with tool groups
# TBD: Have a client utility to hide the pre processing tools.
# Attachments are confusing right now since they are inserted into memory first and retireved through RAG, even before a question is asked.
#

View file

@ -28,7 +28,6 @@ from llama_stack.apis.scoring_functions import (
from llama_stack.apis.shields import Shield, Shields from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
BuiltInToolDef, BuiltInToolDef,
CustomToolDef,
MCPToolGroupDef, MCPToolGroupDef,
Tool, Tool,
ToolGroup, ToolGroup,
@ -36,6 +35,7 @@ from llama_stack.apis.tools import (
ToolGroups, ToolGroups,
ToolHost, ToolHost,
ToolPromptFormat, ToolPromptFormat,
UserDefinedToolDef,
UserDefinedToolGroupDef, UserDefinedToolGroupDef,
) )
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
@ -540,7 +540,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
raise ValueError(f"Unknown tool group: {tool_group}") raise ValueError(f"Unknown tool group: {tool_group}")
for tool_def in tool_defs: for tool_def in tool_defs:
if isinstance(tool_def, CustomToolDef): if isinstance(tool_def, UserDefinedToolDef):
tools.append( tools.append(
Tool( Tool(
identifier=tool_def.name, identifier=tool_def.name,

View file

@ -429,9 +429,9 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = 0 n_iter = 0
# Build a map of custom tools to their definitions for faster lookup # Build a map of custom tools to their definitions for faster lookup
custom_tools = {} client_tools = {}
for tool in self.agent_config.custom_tools: for tool in self.agent_config.client_tools:
custom_tools[tool.name] = tool client_tools[tool.name] = tool
while True: while True:
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -560,7 +560,7 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
log.info(f"{str(message)}") log.info(f"{str(message)}")
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
if tool_call.tool_name in custom_tools: if tool_call.tool_name in client_tools:
yield message yield message
return return
@ -656,7 +656,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _get_tools(self) -> List[ToolDefinition]: async def _get_tools(self) -> List[ToolDefinition]:
ret = [] ret = []
for tool in self.agent_config.custom_tools: for tool in self.agent_config.client_tools:
params = {} params = {}
for param in tool.parameters: for param in tool.parameters:
params[param.name] = ToolParamDefinition( params[param.name] = ToolParamDefinition(
@ -672,7 +672,7 @@ class ChatAgent(ShieldRunnerMixin):
parameters=params, parameters=params,
) )
) )
for tool_name in self.agent_config.available_tools: for tool_name in self.agent_config.tool_names:
tool = await self.tool_groups_api.get_tool(tool_name) tool = await self.tool_groups_api.get_tool(tool_name)
if tool.built_in_type: if tool.built_in_type:
ret.append(ToolDefinition(tool_name=tool.built_in_type)) ret.append(ToolDefinition(tool_name=tool.built_in_type))

View file

@ -8,13 +8,13 @@ from typing import Any, Dict, List
from urllib.parse import urlparse from urllib.parse import urlparse
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
CustomToolDef,
MCPToolGroupDef, MCPToolGroupDef,
ToolDef, ToolDef,
ToolGroupDef, ToolGroupDef,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
ToolRuntime, ToolRuntime,
UserDefinedToolDef,
) )
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -53,7 +53,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
) )
) )
tools.append( tools.append(
CustomToolDef( UserDefinedToolDef(
name=tool.name, name=tool.name,
description=tool.description, description=tool.description,
parameters=parameters, parameters=parameters,

View file

@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool(
agent_config = AgentConfig( agent_config = AgentConfig(
**{ **{
**common_params, **common_params,
"available_tools": [tool_name], "tool_names": [tool_name],
} }
) )

View file

@ -13,9 +13,9 @@ from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.models import ModelInput, ModelType from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
BuiltInToolDef, BuiltInToolDef,
CustomToolDef,
ToolGroupInput, ToolGroupInput,
ToolParameter, ToolParameter,
UserDefinedToolDef,
UserDefinedToolGroupDef, UserDefinedToolGroupDef,
) )
from llama_stack.distribution.datatypes import Api, Provider from llama_stack.distribution.datatypes import Api, Provider
@ -50,7 +50,7 @@ def tool_group_input_memory() -> ToolGroupInput:
tool_group_id="memory_group", tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef( tool_group=UserDefinedToolGroupDef(
tools=[ tools=[
CustomToolDef( UserDefinedToolDef(
name="memory", name="memory",
description="Query the memory bank", description="Query the memory bank",
parameters=[ parameters=[

View file

@ -151,7 +151,7 @@ def test_agent_simple(llama_stack_client, agent_config):
def test_builtin_tool_brave_search(llama_stack_client, agent_config): def test_builtin_tool_brave_search(llama_stack_client, agent_config):
agent_config = { agent_config = {
**agent_config, **agent_config,
"available_tools": [ "tool_names": [
"brave_search", "brave_search",
], ],
} }
@ -181,7 +181,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
def test_builtin_tool_code_execution(llama_stack_client, agent_config): def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = { agent_config = {
**agent_config, **agent_config,
"available_tools": [ "tool_names": [
"code_interpreter", "code_interpreter",
], ],
} }
@ -209,12 +209,12 @@ def test_custom_tool(llama_stack_client, agent_config):
agent_config = { agent_config = {
**agent_config, **agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct", "model": "meta-llama/Llama-3.2-3B-Instruct",
"available_tools": ["brave_search"], "tool_names": ["brave_search"],
"custom_tools": [custom_tool.get_tool_definition()], "client_tools": [custom_tool.get_tool_definition()],
"tool_prompt_format": "python_list", "tool_prompt_format": "python_list",
} }
agent = Agent(llama_stack_client, agent_config, custom_tools=(custom_tool,)) agent = Agent(llama_stack_client, agent_config, client_tools=(custom_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}") session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn( response = agent.create_turn(