mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Added implementation for get_agents_step and get_agents_turn
This commit is contained in:
parent
c2e8011175
commit
761217dcc4
2 changed files with 37 additions and 5 deletions
|
@ -443,11 +443,12 @@ class Agents(Protocol):
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
|
session_id: str
|
||||||
) -> Turn: ...
|
) -> Turn: ...
|
||||||
|
|
||||||
@webmethod(route="/agents/step/get")
|
@webmethod(route="/agents/step/get")
|
||||||
async def get_agents_step(
|
async def get_agents_step(
|
||||||
self, agent_id: str, turn_id: str, step_id: str
|
self, agent_id: str, turn_id: str, step_id: str, session_id: str
|
||||||
) -> AgentStepResponse: ...
|
) -> AgentStepResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/agents/session/create")
|
@webmethod(route="/agents/session/create")
|
||||||
|
|
|
@ -138,13 +138,44 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
async for event in agent.create_and_execute_turn(request):
|
async for event in agent.create_and_execute_turn(request):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
async def get_agents_turn(self, agent_id: str, turn_id: str) -> Turn:
|
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||||
raise NotImplementedError()
|
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||||
|
try:
|
||||||
|
turn = json.loads(turn)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not JSON decode turn for {turn_id}"
|
||||||
|
) from e
|
||||||
|
try:
|
||||||
|
turn = Turn(**turn)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not validate(?) Turns for {turn_id}"
|
||||||
|
) from e
|
||||||
|
return turn
|
||||||
|
|
||||||
async def get_agents_step(
|
async def get_agents_step(
|
||||||
self, agent_id: str, turn_id: str, step_id: str
|
self, agent_id: str, turn_id: str, session_id: str, step_id: str
|
||||||
) -> AgentStepResponse:
|
) -> AgentStepResponse:
|
||||||
raise NotImplementedError()
|
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||||
|
try:
|
||||||
|
turn = json.loads(turn)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not JSON decode turn for {turn_id}"
|
||||||
|
) from e
|
||||||
|
try:
|
||||||
|
turn = Turn(**turn)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not validate(?) Turns for {turn_id}"
|
||||||
|
) from e
|
||||||
|
steps = turn.steps
|
||||||
|
for step in steps:
|
||||||
|
if step.step_id == step_id:
|
||||||
|
return AgentStepResponse(step=step)
|
||||||
|
raise ValueError("Provided step_id could not be found")
|
||||||
|
|
||||||
|
|
||||||
async def get_agents_session(
|
async def get_agents_session(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue