Get the agents method also

This commit is contained in:
Ashwin Bharambe 2024-10-18 19:57:34 -07:00
parent 627edaf407
commit bcaf639dd6
4 changed files with 9 additions and 11 deletions

View file

@ -421,10 +421,8 @@ class Agents(Protocol):
agent_config: AgentConfig, agent_config: AgentConfig,
) -> AgentCreateResponse: ... ) -> AgentCreateResponse: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create") @webmethod(route="/agents/turn/create")
def create_agent_turn( async def create_agent_turn(
self, self,
agent_id: str, agent_id: str,
session_id: str, session_id: str,

View file

@ -67,14 +67,14 @@ class AgentsClient(Agents):
response.raise_for_status() response.raise_for_status()
return AgentSessionCreateResponse(**response.json()) return AgentSessionCreateResponse(**response.json())
def create_agent_turn( async def create_agent_turn(
self, self,
request: AgentTurnCreateRequest, request: AgentTurnCreateRequest,
) -> AsyncGenerator: ) -> AsyncGenerator:
if request.stream: if request.stream:
return self._stream_agent_turn(request) return self._stream_agent_turn(request)
else: else:
return self._nonstream_agent_turn(request) return await self._nonstream_agent_turn(request)
async def _stream_agent_turn( async def _stream_agent_turn(
self, request: AgentTurnCreateRequest self, request: AgentTurnCreateRequest
@ -126,7 +126,7 @@ async def _run_agent(
for content in user_prompts: for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"]) cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agent_turn( iterator = await api.create_agent_turn(
AgentTurnCreateRequest( AgentTurnCreateRequest(
agent_id=create_response.agent_id, agent_id=create_response.agent_id,
session_id=session_response.session_id, session_id=session_response.session_id,

View file

@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id, session_id=session_id,
) )
def create_agent_turn( async def create_agent_turn(
self, self,
agent_id: str, agent_id: str,
session_id: str, session_id: str,

View file

@ -116,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages):
) )
turn_response = [ turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request) chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
] ]
assert len(turn_response) > 0 assert len(turn_response) > 0
@ -204,7 +204,7 @@ async def test_rag_agent_as_attachments(
) )
turn_response = [ turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request) chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
] ]
assert len(turn_response) > 0 assert len(turn_response) > 0
@ -218,7 +218,7 @@ async def test_rag_agent_as_attachments(
) )
turn_response = [ turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request) chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
] ]
assert len(turn_response) > 0 assert len(turn_response) > 0
@ -270,7 +270,7 @@ async def test_create_agent_turn_with_brave_search(
) )
turn_response = [ turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request) chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
] ]
assert len(turn_response) > 0 assert len(turn_response) > 0