From 7dae81cb68cb90d2c1550e322292aa9899f28cd0 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 12:57:18 -0800 Subject: [PATCH] tmp --- llama_stack/apis/agents/agents.py | 6 +++--- .../agents/meta_reference/agent_instance.py | 19 +++++++++++++++++-- tests/client-sdk/agents/test_agents.py | 12 ++++++++++-- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 2f374b638..382a67a57 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -344,15 +344,15 @@ class Agents(Protocol): ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod( - route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/submit_tool_response_messages", + route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/tool_responses", method="POST", ) - async def submit_tool_response_messages( + async def submit_tool_responses( self, agent_id: str, session_id: str, turn_id: str, - tool_response_messages: List[ToolResponseMessage], + tool_responses: Dict[str, ToolResponseMessage], ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod( diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 1c21df57f..779dcf74d 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -31,6 +31,7 @@ from llama_stack.apis.agents import ( AgentTurnResponseStepStartPayload, AgentTurnResponseStreamChunk, AgentTurnResponseTurnCompletePayload, + AgentTurnResponseTurnPendingPayload, AgentTurnResponseTurnStartPayload, Attachment, Document, @@ -62,7 +63,11 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO -from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition +from llama_stack.models.llama.datatypes import ( + BuiltinTool, + ToolCall, + ToolParamDefinition, +) from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing @@ -222,6 +227,15 @@ class ChatAgent(ShieldRunnerMixin): ) await self.storage.add_turn_to_session(request.session_id, turn) + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnPendingPayload( + turn=turn, + ) + ) + ) + else: chunk = AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseTurnCompletePayload( @@ -229,7 +243,8 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - yield chunk + + yield chunk async def run( self, diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index e5380d357..6b8caec25 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack_client.types.tool_def_param import Parameter -from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig -from llama_stack.apis.agents.agents import ToolChoice +from llama_stack.apis.agents.agents import ( + AgentConfig as Server__AgentConfig, +) +from llama_stack.apis.agents.agents import ( + ToolChoice, +) class TestClientTool(ClientTool): @@ -314,6 +318,10 @@ def test_custom_tool(llama_stack_client, agent_config): ], session_id=session_id, ) + from rich.pretty import pprint + + for x in response: + pprint(x) logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs)