diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index d116b1448..33ca52363 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -3705,10 +3705,10 @@
"type": "string"
}
},
- "tool_names": {
+ "tools": {
"type": "array",
"items": {
- "type": "string"
+ "$ref": "#/components/schemas/AgentTool"
}
},
"client_tools": {
@@ -3717,12 +3717,6 @@
"$ref": "#/components/schemas/UserDefinedToolDef"
}
},
- "preprocessing_tools": {
- "type": "array",
- "items": {
- "type": "string"
- }
- },
"tool_choice": {
"$ref": "#/components/schemas/ToolChoice",
"default": "auto"
@@ -3753,6 +3747,51 @@
"enable_session_persistence"
]
},
+ "AgentTool": {
+ "oneOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string"
+ },
+ "args": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "name",
+ "args"
+ ]
+ }
+ ]
+ },
"ToolParameter": {
"type": "object",
"properties": {
@@ -3934,6 +3973,12 @@
},
"stream": {
"type": "boolean"
+ },
+ "tools": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/AgentTool"
+ }
}
},
"additionalProperties": false,
@@ -7944,6 +7989,10 @@
"name": "AgentStepResponse",
"description": ""
},
+ {
+ "name": "AgentTool",
+ "description": ""
+ },
{
"name": "AgentTurnResponseEvent",
"description": "Streamed agent execution response.\n\n"
@@ -8691,6 +8740,7 @@
"AgentCreateResponse",
"AgentSessionCreateResponse",
"AgentStepResponse",
+ "AgentTool",
"AgentTurnResponseEvent",
"AgentTurnResponseStepCompletePayload",
"AgentTurnResponseStepProgressPayload",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index c1097107e..4da311cf0 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -38,22 +38,18 @@ components:
items:
type: string
type: array
- preprocessing_tools:
- items:
- type: string
- type: array
sampling_params:
$ref: '#/components/schemas/SamplingParams'
tool_choice:
$ref: '#/components/schemas/ToolChoice'
default: auto
- tool_names:
- items:
- type: string
- type: array
tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat'
default: json
+ tools:
+ items:
+ $ref: '#/components/schemas/AgentTool'
+ type: array
required:
- max_infer_iters
- model
@@ -88,6 +84,27 @@ components:
required:
- step
type: object
+ AgentTool:
+ oneOf:
+ - type: string
+ - additionalProperties: false
+ properties:
+ args:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ name:
+ type: string
+ required:
+ - name
+ - args
+ type: object
AgentTurnResponseEvent:
additionalProperties: false
properties:
@@ -611,6 +628,10 @@ components:
type: string
stream:
type: boolean
+ tools:
+ items:
+ $ref: '#/components/schemas/AgentTool'
+ type: array
required:
- agent_id
- session_id
@@ -4726,6 +4747,8 @@ tags:
- description:
name: AgentStepResponse
+- description:
+ name: AgentTool
- description: 'Streamed agent execution response.
@@ -5257,6 +5280,7 @@ x-tagGroups:
- AgentCreateResponse
- AgentSessionCreateResponse
- AgentStepResponse
+ - AgentTool
- AgentTurnResponseEvent
- AgentTurnResponseStepCompletePayload
- AgentTurnResponseStepProgressPayload
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 09184d09a..18bbcd95c 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -18,7 +18,7 @@ from typing import (
Union,
)
-from llama_models.schema_utils import json_schema_type, webmethod
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
@@ -132,14 +132,27 @@ class Session(BaseModel):
memory_bank: Optional[MemoryBank] = None
+class AgentToolWithArgs(BaseModel):
+ name: str
+ args: Dict[str, Any]
+
+
+AgentTool = register_schema(
+ Union[
+ str,
+ AgentToolWithArgs,
+ ],
+ name="AgentTool",
+)
+
+
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
- tool_names: Optional[List[str]] = Field(default_factory=list)
+ tools: Optional[List[AgentTool]] = Field(default_factory=list)
client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
- preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@@ -295,6 +308,7 @@ class Agents(Protocol):
]
],
stream: Optional[bool] = False,
+ tools: Optional[List[AgentTool]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index b035ac098..700fa565c 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -13,7 +13,7 @@ import secrets
import string
import uuid
from datetime import datetime
-from typing import AsyncGenerator, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional
from urllib.parse import urlparse
import httpx
@@ -21,6 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe
from llama_stack.apis.agents import (
AgentConfig,
+ AgentTool,
+ AgentToolWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
@@ -188,6 +190,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
+ tools_for_turn=request.tools,
):
if isinstance(chunk, CompletionMessage):
log.info(
@@ -237,6 +240,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages: List[Message],
sampling_params: SamplingParams,
stream: bool = False,
+ tools_for_turn: Optional[List[AgentTool]] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@@ -253,7 +257,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
- session_id, turn_id, input_messages, sampling_params, stream
+ session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn
):
if isinstance(res, bool):
return
@@ -348,82 +352,90 @@ class ChatAgent(ShieldRunnerMixin):
input_messages: List[Message],
sampling_params: SamplingParams,
stream: bool = False,
+ tools_for_turn: Optional[List[AgentTool]] = None,
) -> AsyncGenerator:
- if self.agent_config.preprocessing_tools:
- with tracing.span("preprocessing_tools") as span:
- for tool_name in self.agent_config.preprocessing_tools:
- step_id = str(uuid.uuid4())
- yield AgentTurnResponseStreamChunk(
- event=AgentTurnResponseEvent(
- payload=AgentTurnResponseStepStartPayload(
- step_type=StepType.tool_execution.value,
- step_id=step_id,
- )
- )
- )
- args = dict(
- session_id=session_id,
- turn_id=turn_id,
- input_messages=input_messages,
- )
- yield AgentTurnResponseStreamChunk(
- event=AgentTurnResponseEvent(
- payload=AgentTurnResponseStepProgressPayload(
- step_type=StepType.tool_execution.value,
- step_id=step_id,
- tool_call_delta=ToolCallDelta(
- parse_status=ToolCallParseStatus.success,
- content=ToolCall(
- call_id="", tool_name=tool_name, arguments={}
- ),
- ),
- )
- )
- )
- result = await self.tool_runtime_api.invoke_tool(
- tool_name=tool_name,
- args=args,
- )
+        # Per-tool args: agent-config defaults, overridden by turn-level tools.
+        tool_args = {}
+        for tool in (self.agent_config.tools or []) + (tools_for_turn or []):
+            if isinstance(tool, AgentToolWithArgs):
+                tool_args[tool.name] = tool.args
- yield AgentTurnResponseStreamChunk(
- event=AgentTurnResponseEvent(
- payload=AgentTurnResponseStepCompletePayload(
- step_type=StepType.tool_execution.value,
- step_id=step_id,
- step_details=ToolExecutionStep(
- step_id=step_id,
- turn_id=turn_id,
- tool_calls=[
- ToolCall(
- call_id="",
- tool_name=tool_name,
- arguments={},
- )
- ],
- tool_responses=[
- ToolResponse(
- call_id="",
- tool_name=tool_name,
- content=result.content,
- )
- ],
- ),
- )
+ tool_defs = await self._get_tool_defs(tools_for_turn)
+ if "memory" in tool_defs and len(input_messages) > 0:
+ with tracing.span("memory_tool") as span:
+ step_id = str(uuid.uuid4())
+ yield AgentTurnResponseStreamChunk(
+ event=AgentTurnResponseEvent(
+ payload=AgentTurnResponseStepStartPayload(
+ step_type=StepType.tool_execution.value,
+ step_id=step_id,
)
)
- span.set_attribute(
- "input", [m.model_dump_json() for m in input_messages]
+ )
+ extra_args = tool_args.get("memory", {})
+ args = {
+ # Query memory with the last message's content
+ "query": input_messages[-1],
+ **extra_args,
+ }
+ serialized_args = tracing.serialize_value(args)
+ yield AgentTurnResponseStreamChunk(
+ event=AgentTurnResponseEvent(
+ payload=AgentTurnResponseStepProgressPayload(
+ step_type=StepType.tool_execution.value,
+ step_id=step_id,
+ tool_call_delta=ToolCallDelta(
+ parse_status=ToolCallParseStatus.success,
+ content=ToolCall(
+ call_id="",
+ tool_name="memory",
+ arguments=serialized_args,
+ ),
+ ),
+ )
)
- span.set_attribute("output", result.content)
- span.set_attribute("error_code", result.error_code)
- span.set_attribute("error_message", result.error_message)
- if isinstance(tool_name, BuiltinTool):
- span.set_attribute("tool_name", tool_name.value)
- else:
- span.set_attribute("tool_name", tool_name)
- if result.error_code == 0:
- last_message = input_messages[-1]
- last_message.context = result.content
+ )
+ result = await self.tool_runtime_api.invoke_tool(
+ tool_name="memory",
+ args=args,
+ )
+
+ yield AgentTurnResponseStreamChunk(
+ event=AgentTurnResponseEvent(
+ payload=AgentTurnResponseStepCompletePayload(
+ step_type=StepType.tool_execution.value,
+ step_id=step_id,
+ step_details=ToolExecutionStep(
+ step_id=step_id,
+ turn_id=turn_id,
+ tool_calls=[
+ ToolCall(
+ call_id="",
+ tool_name="memory",
+                                arguments=serialized_args,
+ )
+ ],
+ tool_responses=[
+ ToolResponse(
+ call_id="",
+ tool_name="memory",
+ content=result.content,
+ )
+ ],
+ ),
+ )
+ )
+ )
+ span.set_attribute(
+ "input", [m.model_dump_json() for m in input_messages]
+ )
+ span.set_attribute("output", result.content)
+ span.set_attribute("error_code", result.error_code)
+ span.set_attribute("error_message", result.error_message)
+ span.set_attribute("tool_name", "memory")
+ if result.error_code == 0:
+ last_message = input_messages[-1]
+ last_message.context = result.content
output_attachments = []
@@ -451,7 +463,11 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
- tools=await self._get_tools(),
+ tools=[
+ tool
+ for tool in tool_defs.values()
+ if tool.tool_name != "memory"
+ ],
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
@@ -654,44 +670,66 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
- async def _get_tools(self) -> List[ToolDefinition]:
- ret = []
- for tool in self.agent_config.client_tools:
- params = {}
- for param in tool.parameters:
- params[param.name] = ToolParamDefinition(
- param_type=param.parameter_type,
- description=param.description,
- required=param.required,
- default=param.default,
- )
- ret.append(
- ToolDefinition(
- tool_name=tool.name,
- description=tool.description,
- parameters=params,
- )
+ async def _get_tool_defs(
+ self, tools_for_turn: Optional[List[AgentTool]]
+ ) -> Dict[str, ToolDefinition]:
+ # Determine which tools to include
+ agent_config_tools = set(
+ tool.name if isinstance(tool, AgentToolWithArgs) else tool
+ for tool in self.agent_config.tools
+ )
+ tools_for_turn_set = (
+ agent_config_tools
+ if tools_for_turn is None
+ else {
+ tool.name if isinstance(tool, AgentToolWithArgs) else tool
+ for tool in tools_for_turn
+ }
+ )
+
+ ret = {}
+
+ for tool_def in self.agent_config.client_tools:
+ ret[tool_def.name] = ToolDefinition(
+ tool_name=tool_def.name,
+ description=tool_def.description,
+ parameters={
+ param.name: ToolParamDefinition(
+ param_type=param.parameter_type,
+ description=param.description,
+ required=param.required,
+ default=param.default,
+ )
+ for param in tool_def.parameters
+ },
)
- for tool_name in self.agent_config.tool_names:
- tool = await self.tool_groups_api.get_tool(tool_name)
- if tool.built_in_type:
- ret.append(ToolDefinition(tool_name=tool.built_in_type))
+
+ for tool_name in agent_config_tools:
+ if tool_name not in tools_for_turn_set:
continue
- params = {}
- for param in tool.parameters:
- params[param.name] = ToolParamDefinition(
- param_type=param.parameter_type,
- description=param.description,
- required=param.required,
- default=param.default,
- )
- ret.append(
- ToolDefinition(
- tool_name=tool.identifier,
- description=tool.description,
- parameters=params,
+
+ tool_def = await self.tool_groups_api.get_tool(tool_name)
+
+ if tool_def.built_in_type:
+ ret[tool_def.built_in_type] = ToolDefinition(
+ tool_name=tool_def.built_in_type
)
+ continue
+
+ ret[tool_def.identifier] = ToolDefinition(
+ tool_name=tool_def.identifier,
+ description=tool_def.description,
+ parameters={
+ param.name: ToolParamDefinition(
+ param_type=param.parameter_type,
+ description=param.description,
+ required=param.required,
+ default=param.default,
+ )
+ for param in tool_def.parameters
+ },
)
+
return ret
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index 0515c9a5e..ab7f8878f 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -19,6 +19,7 @@ from llama_stack.apis.agents import (
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
+ AgentTool,
AgentTurnCreateRequest,
Session,
Turn,
@@ -145,6 +146,7 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage,
]
],
        stream: Optional[bool] = False,
+        tools: Optional[List[AgentTool]] = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
@@ -152,6 +154,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id,
messages=messages,
stream=True,
+ tools=tools,
)
if stream:
return self._create_agent_turn_streaming(request)
diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py
index d492309cd..cad123696 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py
@@ -54,14 +54,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
return []
async def _retrieve_context(
- self, messages: List[Message], bank_ids: List[str]
+ self, message: Message, bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
if not bank_ids:
return None
- if len(messages) == 0:
- return None
-
- message = messages[-1] # only use the last message as input to the query
query = await generate_rag_query(
self.config.query_generator_config,
message,
@@ -113,10 +109,15 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
config = MemoryToolConfig()
if tool.metadata.get("config") is not None:
config = MemoryToolConfig(**tool.metadata["config"])
-
+ if "memory_bank_id" in args:
+ bank_ids = [args["memory_bank_id"]]
+ else:
+ bank_ids = [
+ bank_config.bank_id for bank_config in config.memory_bank_configs
+ ]
context = await self._retrieve_context(
- args["input_messages"],
- [bank_config.bank_id for bank_config in config.memory_bank_configs],
+ args["query"],
+ bank_ids,
)
if context is None:
context = []
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 327132b0a..3dad5ade4 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -7,11 +7,8 @@
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import CoreModelId
-
from llama_models.llama3.api.chat_format import ChatFormat
-
from llama_models.llama3.api.tokenizer import Tokenizer
-
from together import Together
from llama_stack.apis.common.content_types import InterleavedContent
@@ -53,7 +50,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
-
MODEL_ALIASES = [
build_model_alias(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 44b0f8a2e..cb20e5890 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool(
agent_config = AgentConfig(
**{
**common_params,
- "tool_names": [tool_name],
+ "tools": [tool_name],
}
)
@@ -268,7 +268,7 @@ class TestAgents:
agent_config = AgentConfig(
**{
**common_params,
- "preprocessing_tools": ["memory"],
+ "tools": ["memory"],
"tool_choice": ToolChoice.auto,
}
)
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 1630ef34b..64c3c159f 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -9,7 +9,7 @@ from typing import Dict, List
from uuid import uuid4
import pytest
-from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client.lib.agents.agent import Agent, AugmentConfigWithMemoryTool
from llama_stack_client.lib.agents.client_tool import ClientTool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import ToolResponseMessage
@@ -151,11 +151,10 @@ def test_agent_simple(llama_stack_client, agent_config):
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
agent_config = {
**agent_config,
- "tool_names": [
+ "tools": [
"brave_search",
],
}
- print(f"Agent Config: {agent_config}")
agent = Agent(llama_stack_client, agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -181,7 +180,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = {
**agent_config,
- "tool_names": [
+ "tools": [
"code_interpreter",
],
}
@@ -209,7 +208,7 @@ def test_custom_tool(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct",
- "tool_names": ["brave_search"],
+ "tools": ["brave_search"],
"client_tools": [client_tool.get_tool_definition()],
"tool_prompt_format": "python_list",
}
@@ -252,8 +251,12 @@ def test_rag_agent(llama_stack_client, agent_config):
for i, url in enumerate(urls)
]
- agent = Agent.with_memory(llama_stack_client, agent_config)
- [agent.add_document(document) for document in documents]
+ memory_bank_id = AugmentConfigWithMemoryTool(agent_config, llama_stack_client)
+ agent = Agent(llama_stack_client, agent_config)
+ llama_stack_client.memory.insert(
+ bank_id=memory_bank_id,
+ documents=documents,
+ )
session_id = agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
@@ -271,8 +274,16 @@ def test_rag_agent(llama_stack_client, agent_config):
}
],
session_id=session_id,
+ tools=[
+ {
+ "name": "memory",
+ "args": {
+ "memory_bank_id": memory_bank_id,
+ },
+ }
+ ],
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
- assert "Tool:memory-tool" in logs_str
+ assert "Tool:memory" in logs_str