From bcaf639dd6bb1ccf4fa70a992001dacee2bde19d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 18 Oct 2024 19:57:34 -0700 Subject: [PATCH] Get the agents method also --- llama_stack/apis/agents/agents.py | 4 +--- llama_stack/apis/agents/client.py | 6 +++--- .../providers/impls/meta_reference/agents/agents.py | 2 +- llama_stack/providers/tests/agents/test_agents.py | 8 ++++---- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index de710a94f..db0b1a269 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -421,10 +421,8 @@ class Agents(Protocol): agent_config: AgentConfig, ) -> AgentCreateResponse: ... - # This method is not `async def` because it can result in either an - # `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`. @webmethod(route="/agents/turn/create") - def create_agent_turn( + async def create_agent_turn( self, agent_id: str, session_id: str, diff --git a/llama_stack/apis/agents/client.py b/llama_stack/apis/agents/client.py index 32bc9abdd..b45447328 100644 --- a/llama_stack/apis/agents/client.py +++ b/llama_stack/apis/agents/client.py @@ -67,14 +67,14 @@ class AgentsClient(Agents): response.raise_for_status() return AgentSessionCreateResponse(**response.json()) - def create_agent_turn( + async def create_agent_turn( self, request: AgentTurnCreateRequest, ) -> AsyncGenerator: if request.stream: return self._stream_agent_turn(request) else: - return self._nonstream_agent_turn(request) + return await self._nonstream_agent_turn(request) async def _stream_agent_turn( self, request: AgentTurnCreateRequest @@ -126,7 +126,7 @@ async def _run_agent( for content in user_prompts: cprint(f"User> {content}", color="white", attrs=["bold"]) - iterator = api.create_agent_turn( + iterator = await api.create_agent_turn( AgentTurnCreateRequest( agent_id=create_response.agent_id, session_id=session_response.session_id, diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index 5a209d0b7..8b3ece978 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents): session_id=session_id, ) - def create_agent_turn( + async def create_agent_turn( self, agent_id: str, session_id: str, diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 6774d3f1f..9c34c3a28 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -116,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages): ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -204,7 +204,7 @@ async def test_rag_agent_as_attachments( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -218,7 +218,7 @@ async def test_rag_agent_as_attachments( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -270,7 +270,7 @@ async def test_create_agent_turn_with_brave_search( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0