diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index de710a94f..db0b1a269 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -421,10 +421,8 @@ class Agents(Protocol): agent_config: AgentConfig, ) -> AgentCreateResponse: ... - # This method is not `async def` because it can result in either an - # `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`. @webmethod(route="/agents/turn/create") - def create_agent_turn( + async def create_agent_turn( self, agent_id: str, session_id: str, diff --git a/llama_stack/apis/agents/client.py b/llama_stack/apis/agents/client.py index 32bc9abdd..b45447328 100644 --- a/llama_stack/apis/agents/client.py +++ b/llama_stack/apis/agents/client.py @@ -67,14 +67,14 @@ class AgentsClient(Agents): response.raise_for_status() return AgentSessionCreateResponse(**response.json()) - def create_agent_turn( + async def create_agent_turn( self, request: AgentTurnCreateRequest, ) -> AsyncGenerator: if request.stream: return self._stream_agent_turn(request) else: - return self._nonstream_agent_turn(request) + return await self._nonstream_agent_turn(request) async def _stream_agent_turn( self, request: AgentTurnCreateRequest @@ -126,7 +126,7 @@ async def _run_agent( for content in user_prompts: cprint(f"User> {content}", color="white", attrs=["bold"]) - iterator = api.create_agent_turn( + iterator = await api.create_agent_turn( AgentTurnCreateRequest( agent_id=create_response.agent_id, session_id=session_response.session_id, diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index 5a209d0b7..8b3ece978 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents): session_id=session_id, ) - def create_agent_turn( + async def create_agent_turn( self, agent_id: str, session_id: str, diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 6774d3f1f..9c34c3a28 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -116,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages): ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -204,7 +204,7 @@ async def test_rag_agent_as_attachments( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -218,7 +218,7 @@ async def test_rag_agent_as_attachments( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -270,7 +270,7 @@ async def test_create_agent_turn_with_brave_search( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0