address feedback

This commit is contained in:
Dinesh Yeduguru 2025-01-02 18:42:20 -08:00
parent ee542a7373
commit 16d1f66f55
9 changed files with 286 additions and 149 deletions

View file

@ -3705,10 +3705,10 @@
"type": "string"
}
},
"tool_names": {
"tools": {
"type": "array",
"items": {
"type": "string"
"$ref": "#/components/schemas/AgentTool"
}
},
"client_tools": {
@ -3717,12 +3717,6 @@
"$ref": "#/components/schemas/UserDefinedToolDef"
}
},
"preprocessing_tools": {
"type": "array",
"items": {
"type": "string"
}
},
"tool_choice": {
"$ref": "#/components/schemas/ToolChoice",
"default": "auto"
@ -3753,6 +3747,51 @@
"enable_session_persistence"
]
},
"AgentTool": {
"oneOf": [
{
"type": "string"
},
{
"type": "object",
"properties": {
"name": {
"type": "string"
},
"args": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"name",
"args"
]
}
]
},
"ToolParameter": {
"type": "object",
"properties": {
@ -3934,6 +3973,12 @@
},
"stream": {
"type": "boolean"
},
"tools": {
"type": "array",
"items": {
"$ref": "#/components/schemas/AgentTool"
}
}
},
"additionalProperties": false,
@ -7944,6 +7989,10 @@
"name": "AgentStepResponse",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/AgentStepResponse\" />"
},
{
"name": "AgentTool",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/AgentTool\" />"
},
{
"name": "AgentTurnResponseEvent",
"description": "Streamed agent execution response.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/AgentTurnResponseEvent\" />"
@ -8691,6 +8740,7 @@
"AgentCreateResponse",
"AgentSessionCreateResponse",
"AgentStepResponse",
"AgentTool",
"AgentTurnResponseEvent",
"AgentTurnResponseStepCompletePayload",
"AgentTurnResponseStepProgressPayload",

View file

@ -38,22 +38,18 @@ components:
items:
type: string
type: array
preprocessing_tools:
items:
type: string
type: array
sampling_params:
$ref: '#/components/schemas/SamplingParams'
tool_choice:
$ref: '#/components/schemas/ToolChoice'
default: auto
tool_names:
items:
type: string
type: array
tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat'
default: json
tools:
items:
$ref: '#/components/schemas/AgentTool'
type: array
required:
- max_infer_iters
- model
@ -88,6 +84,27 @@ components:
required:
- step
type: object
AgentTool:
oneOf:
- type: string
- additionalProperties: false
properties:
args:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
name:
type: string
required:
- name
- args
type: object
AgentTurnResponseEvent:
additionalProperties: false
properties:
@ -611,6 +628,10 @@ components:
type: string
stream:
type: boolean
tools:
items:
$ref: '#/components/schemas/AgentTool'
type: array
required:
- agent_id
- session_id
@ -4726,6 +4747,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/AgentStepResponse"
/>
name: AgentStepResponse
- description: <SchemaDefinition schemaRef="#/components/schemas/AgentTool" />
name: AgentTool
- description: 'Streamed agent execution response.
@ -5257,6 +5280,7 @@ x-tagGroups:
- AgentCreateResponse
- AgentSessionCreateResponse
- AgentStepResponse
- AgentTool
- AgentTurnResponseEvent
- AgentTurnResponseStepCompletePayload
- AgentTurnResponseStepProgressPayload

View file

@ -18,7 +18,7 @@ from typing import (
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
@ -132,14 +132,27 @@ class Session(BaseModel):
memory_bank: Optional[MemoryBank] = None
# Tool reference that pairs a tool's name with arguments supplied for it.
# The agent run loop collects `args` per tool name and, for the "memory"
# tool, merges them into the arguments passed to the tool invocation.
class AgentToolWithArgs(BaseModel):
# Name of the tool this entry refers to.
name: str
# Extra arguments for that tool, keyed by argument name
# (e.g. {"memory_bank_id": ...} as used by the RAG agent test).
args: Dict[str, Any]
# An agent tool is either a bare tool name (str) or an AgentToolWithArgs
# carrying a name plus arguments. The union is registered under the schema
# name "AgentTool", which corresponds to the `AgentTool` component
# (oneOf string / {name, args} object) in the generated API spec.
AgentTool = register_schema(
Union[
str,
AgentToolWithArgs,
],
name="AgentTool",
)
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
tool_names: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentTool]] = Field(default_factory=list)
client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@ -295,6 +308,7 @@ class Agents(Protocol):
]
],
stream: Optional[bool] = False,
tools: Optional[List[AgentTool]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")

View file

@ -13,7 +13,7 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Optional
from typing import AsyncGenerator, Dict, List, Optional
from urllib.parse import urlparse
import httpx
@ -21,6 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentToolWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
@ -188,6 +190,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
tools_for_turn=request.tools,
):
if isinstance(chunk, CompletionMessage):
log.info(
@ -237,6 +240,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages: List[Message],
sampling_params: SamplingParams,
stream: bool = False,
tools_for_turn: Optional[List[AgentTool]] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -253,7 +257,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
session_id, turn_id, input_messages, sampling_params, stream
session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn
):
if isinstance(res, bool):
return
@ -348,82 +352,90 @@ class ChatAgent(ShieldRunnerMixin):
input_messages: List[Message],
sampling_params: SamplingParams,
stream: bool = False,
tools_for_turn: Optional[List[AgentTool]] = None,
) -> AsyncGenerator:
if self.agent_config.preprocessing_tools:
with tracing.span("preprocessing_tools") as span:
for tool_name in self.agent_config.preprocessing_tools:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
args = dict(
session_id=session_id,
turn_id=turn_id,
input_messages=input_messages,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call_delta=ToolCallDelta(
parse_status=ToolCallParseStatus.success,
content=ToolCall(
call_id="", tool_name=tool_name, arguments={}
),
),
)
)
)
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name,
args=args,
)
tool_args = {}
if tools_for_turn:
for tool in tools_for_turn:
if isinstance(tool, AgentToolWithArgs):
tool_args[tool.name] = tool.args
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[
ToolCall(
call_id="",
tool_name=tool_name,
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name=tool_name,
content=result.content,
)
],
),
)
tool_defs = await self._get_tool_defs(tools_for_turn)
if "memory" in tool_defs and len(input_messages) > 0:
with tracing.span("memory_tool") as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
extra_args = tool_args.get("memory", {})
args = {
# Query memory with the last message's content
"query": input_messages[-1],
**extra_args,
}
serialized_args = tracing.serialize_value(args)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call_delta=ToolCallDelta(
parse_status=ToolCallParseStatus.success,
content=ToolCall(
call_id="",
tool_name="memory",
arguments=serialized_args,
),
),
)
)
span.set_attribute("output", result.content)
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
if isinstance(tool_name, BuiltinTool):
span.set_attribute("tool_name", tool_name.value)
else:
span.set_attribute("tool_name", tool_name)
if result.error_code == 0:
last_message = input_messages[-1]
last_message.context = result.content
)
result = await self.tool_runtime_api.invoke_tool(
tool_name="memory",
args=args,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[
ToolCall(
call_id="",
tool_name="memory",
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name="memory",
content=result.content,
)
],
),
)
)
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("output", result.content)
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
span.set_attribute("tool_name", "memory")
if result.error_code == 0:
last_message = input_messages[-1]
last_message.context = result.content
output_attachments = []
@ -451,7 +463,11 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=await self._get_tools(),
tools=[
tool
for tool in tool_defs.values()
if tool.tool_name != "memory"
],
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
@ -654,44 +670,66 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
async def _get_tools(self) -> List[ToolDefinition]:
ret = []
for tool in self.agent_config.client_tools:
params = {}
for param in tool.parameters:
params[param.name] = ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
ret.append(
ToolDefinition(
tool_name=tool.name,
description=tool.description,
parameters=params,
)
async def _get_tool_defs(
self, tools_for_turn: Optional[List[AgentTool]]
) -> Dict[str, ToolDefinition]:
# Determine which tools to include
agent_config_tools = set(
tool.name if isinstance(tool, AgentToolWithArgs) else tool
for tool in self.agent_config.tools
)
tools_for_turn_set = (
agent_config_tools
if tools_for_turn is None
else {
tool.name if isinstance(tool, AgentToolWithArgs) else tool
for tool in tools_for_turn
}
)
ret = {}
for tool_def in self.agent_config.client_tools:
ret[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
for tool_name in self.agent_config.tool_names:
tool = await self.tool_groups_api.get_tool(tool_name)
if tool.built_in_type:
ret.append(ToolDefinition(tool_name=tool.built_in_type))
for tool_name in agent_config_tools:
if tool_name not in tools_for_turn_set:
continue
params = {}
for param in tool.parameters:
params[param.name] = ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
ret.append(
ToolDefinition(
tool_name=tool.identifier,
description=tool.description,
parameters=params,
tool_def = await self.tool_groups_api.get_tool(tool_name)
if tool_def.built_in_type:
ret[tool_def.built_in_type] = ToolDefinition(
tool_name=tool_def.built_in_type
)
continue
ret[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
return ret

View file

@ -19,6 +19,7 @@ from llama_stack.apis.agents import (
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentTool,
AgentTurnCreateRequest,
Session,
Turn,
@ -145,6 +146,7 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage,
]
],
tools: Optional[List[AgentTool]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
@ -152,6 +154,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id,
messages=messages,
stream=True,
tools=tools,
)
if stream:
return self._create_agent_turn_streaming(request)

View file

@ -54,14 +54,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
return []
async def _retrieve_context(
self, messages: List[Message], bank_ids: List[str]
self, message: Message, bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
if not bank_ids:
return None
if len(messages) == 0:
return None
message = messages[-1] # only use the last message as input to the query
query = await generate_rag_query(
self.config.query_generator_config,
message,
@ -113,10 +109,15 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
config = MemoryToolConfig()
if tool.metadata.get("config") is not None:
config = MemoryToolConfig(**tool.metadata["config"])
if "memory_bank_id" in args:
bank_ids = [args["memory_bank_id"]]
else:
bank_ids = [
bank_config.bank_id for bank_config in config.memory_bank_configs
]
context = await self._retrieve_context(
args["input_messages"],
[bank_config.bank_id for bank_config in config.memory_bank_configs],
args["query"],
bank_ids,
)
if context is None:
context = []

View file

@ -7,11 +7,8 @@
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together
from llama_stack.apis.common.content_types import InterleavedContent
@ -53,7 +50,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
MODEL_ALIASES = [
build_model_alias(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",

View file

@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool(
agent_config = AgentConfig(
**{
**common_params,
"tool_names": [tool_name],
"tools": [tool_name],
}
)
@ -268,7 +268,7 @@ class TestAgents:
agent_config = AgentConfig(
**{
**common_params,
"preprocessing_tools": ["memory"],
"tools": ["memory"],
"tool_choice": ToolChoice.auto,
}
)

View file

@ -9,7 +9,7 @@ from typing import Dict, List
from uuid import uuid4
import pytest
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.agent import Agent, AugmentConfigWithMemoryTool
from llama_stack_client.lib.agents.client_tool import ClientTool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import ToolResponseMessage
@ -151,11 +151,10 @@ def test_agent_simple(llama_stack_client, agent_config):
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tool_names": [
"tools": [
"brave_search",
],
}
print(f"Agent Config: {agent_config}")
agent = Agent(llama_stack_client, agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
@ -181,7 +180,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tool_names": [
"tools": [
"code_interpreter",
],
}
@ -209,7 +208,7 @@ def test_custom_tool(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct",
"tool_names": ["brave_search"],
"tools": ["brave_search"],
"client_tools": [client_tool.get_tool_definition()],
"tool_prompt_format": "python_list",
}
@ -252,8 +251,12 @@ def test_rag_agent(llama_stack_client, agent_config):
for i, url in enumerate(urls)
]
agent = Agent.with_memory(llama_stack_client, agent_config)
[agent.add_document(document) for document in documents]
memory_bank_id = AugmentConfigWithMemoryTool(agent_config, llama_stack_client)
agent = Agent(llama_stack_client, agent_config)
llama_stack_client.memory.insert(
bank_id=memory_bank_id,
documents=documents,
)
session_id = agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
@ -271,8 +274,16 @@ def test_rag_agent(llama_stack_client, agent_config):
}
],
session_id=session_id,
tools=[
{
"name": "memory",
"args": {
"memory_bank_id": memory_bank_id,
},
}
],
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:memory-tool" in logs_str
assert "Tool:memory" in logs_str