From 16d1f66f558290eb2201fc94335296ea97185742 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 2 Jan 2025 18:42:20 -0800 Subject: [PATCH] address feedback --- docs/resources/llama-stack-spec.html | 66 ++++- docs/resources/llama-stack-spec.yaml | 40 ++- llama_stack/apis/agents/agents.py | 20 +- .../agents/meta_reference/agent_instance.py | 254 ++++++++++-------- .../inline/agents/meta_reference/agents.py | 3 + .../inline/tool_runtime/memory/memory.py | 17 +- .../remote/inference/together/together.py | 4 - .../providers/tests/agents/test_agents.py | 4 +- tests/client-sdk/agents/test_agents.py | 27 +- 9 files changed, 286 insertions(+), 149 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index d116b1448..33ca52363 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3705,10 +3705,10 @@ "type": "string" } }, - "tool_names": { + "tools": { "type": "array", "items": { - "type": "string" + "$ref": "#/components/schemas/AgentTool" } }, "client_tools": { @@ -3717,12 +3717,6 @@ "$ref": "#/components/schemas/UserDefinedToolDef" } }, - "preprocessing_tools": { - "type": "array", - "items": { - "type": "string" - } - }, "tool_choice": { "$ref": "#/components/schemas/ToolChoice", "default": "auto" @@ -3753,6 +3747,51 @@ "enable_session_persistence" ] }, + "AgentTool": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "args": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "name", + "args" + ] + } + ] + }, "ToolParameter": { "type": "object", "properties": { @@ -3934,6 +3973,12 @@ }, "stream": { "type": "boolean" + }, + "tools": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AgentTool" + } } }, "additionalProperties": false, @@ -7944,6 +7989,10 @@ "name": "AgentStepResponse", "description": "" }, + { + "name": "AgentTool", + "description": "" + }, { "name": "AgentTurnResponseEvent", "description": "Streamed agent execution response.\n\n" @@ -8691,6 +8740,7 @@ "AgentCreateResponse", "AgentSessionCreateResponse", "AgentStepResponse", + "AgentTool", "AgentTurnResponseEvent", "AgentTurnResponseStepCompletePayload", "AgentTurnResponseStepProgressPayload", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index c1097107e..4da311cf0 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -38,22 +38,18 @@ components: items: type: string type: array - preprocessing_tools: - items: - type: string - type: array sampling_params: $ref: '#/components/schemas/SamplingParams' tool_choice: $ref: '#/components/schemas/ToolChoice' default: auto - tool_names: - items: - type: string - type: array tool_prompt_format: $ref: '#/components/schemas/ToolPromptFormat' default: json + tools: + items: + $ref: '#/components/schemas/AgentTool' + type: array required: - max_infer_iters - model @@ -88,6 +84,27 @@ components: required: - step type: object + AgentTool: + oneOf: + - type: string + - additionalProperties: false + properties: + args: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + name: + type: string + required: + - name + - args + type: object AgentTurnResponseEvent: additionalProperties: false properties: @@ -611,6 +628,10 @@ components: type: string stream: type: boolean + tools: + items: + $ref: '#/components/schemas/AgentTool' + type: array required: - agent_id - session_id @@ -4726,6 +4747,8 @@ tags: - description: name: AgentStepResponse +- description: + name: AgentTool - description: 'Streamed agent execution response. @@ -5257,6 +5280,7 @@ x-tagGroups: - AgentCreateResponse - AgentSessionCreateResponse - AgentStepResponse + - AgentTool - AgentTurnResponseEvent - AgentTurnResponseStepCompletePayload - AgentTurnResponseStepProgressPayload diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 09184d09a..18bbcd95c 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -18,7 +18,7 @@ from typing import ( Union, ) -from llama_models.schema_utils import json_schema_type, webmethod +from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated @@ -132,14 +132,27 @@ class Session(BaseModel): memory_bank: Optional[MemoryBank] = None +class AgentToolWithArgs(BaseModel): + name: str + args: Dict[str, Any] + + +AgentTool = register_schema( + Union[ + str, + AgentToolWithArgs, + ], + name="AgentTool", +) + + class AgentConfigCommon(BaseModel): sampling_params: Optional[SamplingParams] = SamplingParams() input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) - tool_names: Optional[List[str]] = Field(default_factory=list) + tools: Optional[List[AgentTool]] = Field(default_factory=list) client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list) - preprocessing_tools: Optional[List[str]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json @@ -295,6 +308,7 @@ class Agents(Protocol): ] ], stream: Optional[bool] = False, + tools: Optional[List[AgentTool]] = None, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod(route="/agents/turn/get") diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index b035ac098..700fa565c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -13,7 +13,7 @@ import secrets import string import uuid from datetime import datetime -from typing import AsyncGenerator, List, Optional +from typing import AsyncGenerator, Dict, List, Optional from urllib.parse import urlparse import httpx @@ -21,6 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe from llama_stack.apis.agents import ( AgentConfig, + AgentTool, + AgentToolWithArgs, AgentTurnCreateRequest, AgentTurnResponseEvent, AgentTurnResponseEventType, @@ -188,6 +190,7 @@ class ChatAgent(ShieldRunnerMixin): input_messages=messages, sampling_params=self.agent_config.sampling_params, stream=request.stream, + tools_for_turn=request.tools, ): if isinstance(chunk, CompletionMessage): log.info( @@ -237,6 +240,7 @@ class ChatAgent(ShieldRunnerMixin): input_messages: List[Message], sampling_params: SamplingParams, stream: bool = False, + tools_for_turn: Optional[List[AgentTool]] = None, ) -> AsyncGenerator: # Doing async generators makes downstream code much simpler and everything amenable to # streaming. However, it also makes things complicated here because AsyncGenerators cannot @@ -253,7 +257,7 @@ class ChatAgent(ShieldRunnerMixin): yield res async for res in self._run( - session_id, turn_id, input_messages, sampling_params, stream + session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn ): if isinstance(res, bool): return @@ -348,82 +352,90 @@ class ChatAgent(ShieldRunnerMixin): input_messages: List[Message], sampling_params: SamplingParams, stream: bool = False, + tools_for_turn: Optional[List[AgentTool]] = None, ) -> AsyncGenerator: - if self.agent_config.preprocessing_tools: - with tracing.span("preprocessing_tools") as span: - for tool_name in self.agent_config.preprocessing_tools: - step_id = str(uuid.uuid4()) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepStartPayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - ) - ) - ) - args = dict( - session_id=session_id, - turn_id=turn_id, - input_messages=input_messages, - ) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - tool_call_delta=ToolCallDelta( - parse_status=ToolCallParseStatus.success, - content=ToolCall( - call_id="", tool_name=tool_name, arguments={} - ), - ), - ) - ) - ) - result = await self.tool_runtime_api.invoke_tool( - tool_name=tool_name, - args=args, - ) + tool_args = {} + if tools_for_turn: + for tool in tools_for_turn: + if isinstance(tool, AgentToolWithArgs): + tool_args[tool.name] = tool.args - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - step_details=ToolExecutionStep( - step_id=step_id, - turn_id=turn_id, - tool_calls=[ - ToolCall( - call_id="", - tool_name=tool_name, - arguments={}, - ) - ], - tool_responses=[ - ToolResponse( - call_id="", - tool_name=tool_name, - content=result.content, - ) - ], - ), - ) + tool_defs = await self._get_tool_defs(tools_for_turn) + if "memory" in tool_defs and len(input_messages) > 0: + with tracing.span("memory_tool") as span: + step_id = str(uuid.uuid4()) + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( + step_type=StepType.tool_execution.value, + step_id=step_id, ) ) - span.set_attribute( - "input", [m.model_dump_json() for m in input_messages] + ) + extra_args = tool_args.get("memory", {}) + args = { + # Query memory with the last message's content + "query": input_messages[-1], + **extra_args, + } + serialized_args = tracing.serialize_value(args) + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( + step_type=StepType.tool_execution.value, + step_id=step_id, + tool_call_delta=ToolCallDelta( + parse_status=ToolCallParseStatus.success, + content=ToolCall( + call_id="", + tool_name="memory", + arguments=serialized_args, + ), + ), + ) ) - span.set_attribute("output", result.content) - span.set_attribute("error_code", result.error_code) - span.set_attribute("error_message", result.error_message) - if isinstance(tool_name, BuiltinTool): - span.set_attribute("tool_name", tool_name.value) - else: - span.set_attribute("tool_name", tool_name) - if result.error_code == 0: - last_message = input_messages[-1] - last_message.context = result.content + ) + result = await self.tool_runtime_api.invoke_tool( + tool_name="memory", + args=args, + ) + + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.tool_execution.value, + step_id=step_id, + step_details=ToolExecutionStep( + step_id=step_id, + turn_id=turn_id, + tool_calls=[ + ToolCall( + call_id="", + tool_name="memory", + arguments={}, + ) + ], + tool_responses=[ + ToolResponse( + call_id="", + tool_name="memory", + content=result.content, + ) + ], + ), + ) + ) + ) + span.set_attribute( + "input", [m.model_dump_json() for m in input_messages] + ) + span.set_attribute("output", result.content) + span.set_attribute("error_code", result.error_code) + span.set_attribute("error_message", result.error_message) + span.set_attribute("tool_name", "memory") + if result.error_code == 0: + last_message = input_messages[-1] + last_message.context = result.content output_attachments = [] @@ -451,7 +463,11 @@ class ChatAgent(ShieldRunnerMixin): async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=await self._get_tools(), + tools=[ + tool + for tool in tool_defs.values() + if tool.tool_name != "memory" + ], tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, sampling_params=sampling_params, @@ -654,44 +670,66 @@ class ChatAgent(ShieldRunnerMixin): n_iter += 1 - async def _get_tools(self) -> List[ToolDefinition]: - ret = [] - for tool in self.agent_config.client_tools: - params = {} - for param in tool.parameters: - params[param.name] = ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - ret.append( - ToolDefinition( - tool_name=tool.name, - description=tool.description, - parameters=params, - ) + async def _get_tool_defs( + self, tools_for_turn: Optional[List[AgentTool]] + ) -> Dict[str, ToolDefinition]: + # Determine which tools to include + agent_config_tools = set( + tool.name if isinstance(tool, AgentToolWithArgs) else tool + for tool in self.agent_config.tools + ) + tools_for_turn_set = ( + agent_config_tools + if tools_for_turn is None + else { + tool.name if isinstance(tool, AgentToolWithArgs) else tool + for tool in tools_for_turn + } + ) + + ret = {} + + for tool_def in self.agent_config.client_tools: + ret[tool_def.name] = ToolDefinition( + tool_name=tool_def.name, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool_def.parameters + }, ) - for tool_name in self.agent_config.tool_names: - tool = await self.tool_groups_api.get_tool(tool_name) - if tool.built_in_type: - ret.append(ToolDefinition(tool_name=tool.built_in_type)) + + for tool_name in agent_config_tools: + if tool_name not in tools_for_turn_set: continue - params = {} - for param in tool.parameters: - params[param.name] = ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - ret.append( - ToolDefinition( - tool_name=tool.identifier, - description=tool.description, - parameters=params, + + tool_def = await self.tool_groups_api.get_tool(tool_name) + + if tool_def.built_in_type: + ret[tool_def.built_in_type] = ToolDefinition( + tool_name=tool_def.built_in_type ) + continue + + ret[tool_def.identifier] = ToolDefinition( + tool_name=tool_def.identifier, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool_def.parameters + }, ) + return ret diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 0515c9a5e..ab7f8878f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -19,6 +19,7 @@ from llama_stack.apis.agents import ( Agents, AgentSessionCreateResponse, AgentStepResponse, + AgentTool, AgentTurnCreateRequest, Session, Turn, @@ -145,6 +146,7 @@ class MetaReferenceAgentsImpl(Agents): ToolResponseMessage, ] ], + tools: Optional[List[AgentTool]] = None, stream: Optional[bool] = False, ) -> AsyncGenerator: request = AgentTurnCreateRequest( @@ -152,6 +154,7 @@ class MetaReferenceAgentsImpl(Agents): session_id=session_id, messages=messages, stream=True, + tools=tools, ) if stream: return self._create_agent_turn_streaming(request) diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index d492309cd..cad123696 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -54,14 +54,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): return [] async def _retrieve_context( - self, messages: List[Message], bank_ids: List[str] + self, message: Message, bank_ids: List[str] ) -> Optional[List[InterleavedContent]]: if not bank_ids: return None - if len(messages) == 0: - return None - - message = messages[-1] # only use the last message as input to the query query = await generate_rag_query( self.config.query_generator_config, message, @@ -113,10 +109,15 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): config = MemoryToolConfig() if tool.metadata.get("config") is not None: config = MemoryToolConfig(**tool.metadata["config"]) - + if "memory_bank_id" in args: + bank_ids = [args["memory_bank_id"]] + else: + bank_ids = [ + bank_config.bank_id for bank_config in config.memory_bank_configs + ] context = await self._retrieve_context( - args["input_messages"], - [bank_config.bank_id for bank_config in config.memory_bank_configs], + args["query"], + bank_ids, ) if context is None: context = [] diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 327132b0a..3dad5ade4 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -7,11 +7,8 @@ from typing import AsyncGenerator, List, Optional, Union from llama_models.datatypes import CoreModelId - from llama_models.llama3.api.chat_format import ChatFormat - from llama_models.llama3.api.tokenizer import Tokenizer - from together import Together from llama_stack.apis.common.content_types import InterleavedContent @@ -53,7 +50,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import TogetherImplConfig - MODEL_ALIASES = [ build_model_alias( "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 44b0f8a2e..cb20e5890 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool( agent_config = AgentConfig( **{ **common_params, - "tool_names": [tool_name], + "tools": [tool_name], } ) @@ -268,7 +268,7 @@ class TestAgents: agent_config = AgentConfig( **{ **common_params, - "preprocessing_tools": ["memory"], + "tools": ["memory"], "tool_choice": ToolChoice.auto, } ) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 1630ef34b..64c3c159f 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -9,7 +9,7 @@ from typing import Dict, List from uuid import uuid4 import pytest -from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.agent import Agent, AugmentConfigWithMemoryTool from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types import ToolResponseMessage @@ -151,11 +151,10 @@ def test_agent_simple(llama_stack_client, agent_config): def test_builtin_tool_brave_search(llama_stack_client, agent_config): agent_config = { **agent_config, - "tool_names": [ + "tools": [ "brave_search", ], } - print(f"Agent Config: {agent_config}") agent = Agent(llama_stack_client, agent_config) session_id = agent.create_session(f"test-session-{uuid4()}") @@ -181,7 +180,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config): def test_builtin_tool_code_execution(llama_stack_client, agent_config): agent_config = { **agent_config, - "tool_names": [ + "tools": [ "code_interpreter", ], } @@ -209,7 +208,7 @@ def test_custom_tool(llama_stack_client, agent_config): agent_config = { **agent_config, "model": "meta-llama/Llama-3.2-3B-Instruct", - "tool_names": ["brave_search"], + "tools": ["brave_search"], "client_tools": [client_tool.get_tool_definition()], "tool_prompt_format": "python_list", } @@ -252,8 +251,12 @@ def test_rag_agent(llama_stack_client, agent_config): for i, url in enumerate(urls) ] - agent = Agent.with_memory(llama_stack_client, agent_config) - [agent.add_document(document) for document in documents] + memory_bank_id = AugmentConfigWithMemoryTool(agent_config, llama_stack_client) + agent = Agent(llama_stack_client, agent_config) + llama_stack_client.memory.insert( + bank_id=memory_bank_id, + documents=documents, + ) session_id = agent.create_session(f"test-session-{uuid4()}") user_prompts = [ @@ -271,8 +274,16 @@ def test_rag_agent(llama_stack_client, agent_config): } ], session_id=session_id, + tools=[ + { + "name": "memory", + "args": { + "memory_bank_id": memory_bank_id, + }, + } + ], ) logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - assert "Tool:memory-tool" in logs_str + assert "Tool:memory" in logs_str