mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
address feedback
This commit is contained in:
parent
8bf3f8ea56
commit
ac46bd5eb4
8 changed files with 24 additions and 35 deletions
|
@ -36,7 +36,7 @@ from llama_stack.apis.inference import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.memory import MemoryBank
|
from llama_stack.apis.memory import MemoryBank
|
||||||
from llama_stack.apis.safety import SafetyViolation
|
from llama_stack.apis.safety import SafetyViolation
|
||||||
from llama_stack.apis.tools import CustomToolDef
|
from llama_stack.apis.tools import UserDefinedToolDef
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,8 +137,8 @@ class AgentConfigCommon(BaseModel):
|
||||||
|
|
||||||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||||
available_tools: Optional[List[str]] = Field(default_factory=list)
|
tool_names: Optional[List[str]] = Field(default_factory=list)
|
||||||
custom_tools: Optional[List[CustomToolDef]] = Field(default_factory=list)
|
client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
|
||||||
preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
|
preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
|
||||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
|
|
|
@ -48,8 +48,8 @@ class Tool(Resource):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class CustomToolDef(BaseModel):
|
class UserDefinedToolDef(BaseModel):
|
||||||
type: Literal["custom"] = "custom"
|
type: Literal["user_defined"] = "user_defined"
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
parameters: List[ToolParameter]
|
parameters: List[ToolParameter]
|
||||||
|
@ -67,7 +67,7 @@ class BuiltInToolDef(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
ToolDef = register_schema(
|
ToolDef = register_schema(
|
||||||
Annotated[Union[CustomToolDef, BuiltInToolDef], Field(discriminator="type")],
|
Annotated[Union[UserDefinedToolDef, BuiltInToolDef], Field(discriminator="type")],
|
||||||
name="ToolDef",
|
name="ToolDef",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -172,14 +172,3 @@ class ToolRuntime(Protocol):
|
||||||
) -> ToolInvocationResult:
|
) -> ToolInvocationResult:
|
||||||
"""Run a tool with the given arguments"""
|
"""Run a tool with the given arguments"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
# Three tool types:
|
|
||||||
# 1. Built-in tools
|
|
||||||
# 2. Client tools
|
|
||||||
# 3. Model-context-protocol tools
|
|
||||||
|
|
||||||
# Suport registration of agents with tool groups
|
|
||||||
# TBD: Have a client utility to hide the pre processing tools.
|
|
||||||
# Attachments are confusing right now since they are inserted into memory first and retireved through RAG, even before a question is asked.
|
|
||||||
#
|
|
||||||
|
|
|
@ -28,7 +28,6 @@ from llama_stack.apis.scoring_functions import (
|
||||||
from llama_stack.apis.shields import Shield, Shields
|
from llama_stack.apis.shields import Shield, Shields
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
BuiltInToolDef,
|
BuiltInToolDef,
|
||||||
CustomToolDef,
|
|
||||||
MCPToolGroupDef,
|
MCPToolGroupDef,
|
||||||
Tool,
|
Tool,
|
||||||
ToolGroup,
|
ToolGroup,
|
||||||
|
@ -36,6 +35,7 @@ from llama_stack.apis.tools import (
|
||||||
ToolGroups,
|
ToolGroups,
|
||||||
ToolHost,
|
ToolHost,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
|
UserDefinedToolDef,
|
||||||
UserDefinedToolGroupDef,
|
UserDefinedToolGroupDef,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
|
@ -540,7 +540,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
raise ValueError(f"Unknown tool group: {tool_group}")
|
raise ValueError(f"Unknown tool group: {tool_group}")
|
||||||
|
|
||||||
for tool_def in tool_defs:
|
for tool_def in tool_defs:
|
||||||
if isinstance(tool_def, CustomToolDef):
|
if isinstance(tool_def, UserDefinedToolDef):
|
||||||
tools.append(
|
tools.append(
|
||||||
Tool(
|
Tool(
|
||||||
identifier=tool_def.name,
|
identifier=tool_def.name,
|
||||||
|
|
|
@ -429,9 +429,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
n_iter = 0
|
n_iter = 0
|
||||||
# Build a map of custom tools to their definitions for faster lookup
|
# Build a map of custom tools to their definitions for faster lookup
|
||||||
custom_tools = {}
|
client_tools = {}
|
||||||
for tool in self.agent_config.custom_tools:
|
for tool in self.agent_config.client_tools:
|
||||||
custom_tools[tool.name] = tool
|
client_tools[tool.name] = tool
|
||||||
while True:
|
while True:
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
|
@ -560,7 +560,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
else:
|
else:
|
||||||
log.info(f"{str(message)}")
|
log.info(f"{str(message)}")
|
||||||
tool_call = message.tool_calls[0]
|
tool_call = message.tool_calls[0]
|
||||||
if tool_call.tool_name in custom_tools:
|
if tool_call.tool_name in client_tools:
|
||||||
yield message
|
yield message
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -656,7 +656,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
async def _get_tools(self) -> List[ToolDefinition]:
|
async def _get_tools(self) -> List[ToolDefinition]:
|
||||||
ret = []
|
ret = []
|
||||||
for tool in self.agent_config.custom_tools:
|
for tool in self.agent_config.client_tools:
|
||||||
params = {}
|
params = {}
|
||||||
for param in tool.parameters:
|
for param in tool.parameters:
|
||||||
params[param.name] = ToolParamDefinition(
|
params[param.name] = ToolParamDefinition(
|
||||||
|
@ -672,7 +672,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
parameters=params,
|
parameters=params,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for tool_name in self.agent_config.available_tools:
|
for tool_name in self.agent_config.tool_names:
|
||||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||||
if tool.built_in_type:
|
if tool.built_in_type:
|
||||||
ret.append(ToolDefinition(tool_name=tool.built_in_type))
|
ret.append(ToolDefinition(tool_name=tool.built_in_type))
|
||||||
|
|
|
@ -8,13 +8,13 @@ from typing import Any, Dict, List
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
CustomToolDef,
|
|
||||||
MCPToolGroupDef,
|
MCPToolGroupDef,
|
||||||
ToolDef,
|
ToolDef,
|
||||||
ToolGroupDef,
|
ToolGroupDef,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
|
UserDefinedToolDef,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tools.append(
|
tools.append(
|
||||||
CustomToolDef(
|
UserDefinedToolDef(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
description=tool.description,
|
description=tool.description,
|
||||||
parameters=parameters,
|
parameters=parameters,
|
||||||
|
|
|
@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool(
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
**{
|
**{
|
||||||
**common_params,
|
**common_params,
|
||||||
"available_tools": [tool_name],
|
"tool_names": [tool_name],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,9 +13,9 @@ from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
from llama_stack.apis.models import ModelInput, ModelType
|
from llama_stack.apis.models import ModelInput, ModelType
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
BuiltInToolDef,
|
BuiltInToolDef,
|
||||||
CustomToolDef,
|
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
|
UserDefinedToolDef,
|
||||||
UserDefinedToolGroupDef,
|
UserDefinedToolGroupDef,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
@ -50,7 +50,7 @@ def tool_group_input_memory() -> ToolGroupInput:
|
||||||
tool_group_id="memory_group",
|
tool_group_id="memory_group",
|
||||||
tool_group=UserDefinedToolGroupDef(
|
tool_group=UserDefinedToolGroupDef(
|
||||||
tools=[
|
tools=[
|
||||||
CustomToolDef(
|
UserDefinedToolDef(
|
||||||
name="memory",
|
name="memory",
|
||||||
description="Query the memory bank",
|
description="Query the memory bank",
|
||||||
parameters=[
|
parameters=[
|
||||||
|
|
|
@ -151,7 +151,7 @@ def test_agent_simple(llama_stack_client, agent_config):
|
||||||
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"available_tools": [
|
"tool_names": [
|
||||||
"brave_search",
|
"brave_search",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
@ -181,7 +181,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
||||||
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"available_tools": [
|
"tool_names": [
|
||||||
"code_interpreter",
|
"code_interpreter",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
@ -209,12 +209,12 @@ def test_custom_tool(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||||
"available_tools": ["brave_search"],
|
"tool_names": ["brave_search"],
|
||||||
"custom_tools": [custom_tool.get_tool_definition()],
|
"client_tools": [custom_tool.get_tool_definition()],
|
||||||
"tool_prompt_format": "python_list",
|
"tool_prompt_format": "python_list",
|
||||||
}
|
}
|
||||||
|
|
||||||
agent = Agent(llama_stack_client, agent_config, custom_tools=(custom_tool,))
|
agent = Agent(llama_stack_client, agent_config, client_tools=(custom_tool,))
|
||||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||||
|
|
||||||
response = agent.create_turn(
|
response = agent.create_turn(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue