Get the agents method also

This commit is contained in:
Ashwin Bharambe 2024-10-18 19:57:34 -07:00
parent 627edaf407
commit bcaf639dd6
4 changed files with 9 additions and 11 deletions

View file

@ -421,10 +421,8 @@ class Agents(Protocol):
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create")
def create_agent_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,

View file

@ -67,14 +67,14 @@ class AgentsClient(Agents):
response.raise_for_status()
return AgentSessionCreateResponse(**response.json())
def create_agent_turn(
async def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
return self._nonstream_agent_turn(request)
return await self._nonstream_agent_turn(request)
async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
@ -126,7 +126,7 @@ async def _run_agent(
for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agent_turn(
iterator = await api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
session_id=session_response.session_id,

View file

@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id,
)
def create_agent_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,

View file

@ -116,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages):
)
turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
@ -204,7 +204,7 @@ async def test_rag_agent_as_attachments(
)
turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
@ -218,7 +218,7 @@ async def test_rag_agent_as_attachments(
)
turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
@ -270,7 +270,7 @@ async def test_create_agent_turn_with_brave_search(
)
turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0