From 761217dcc4d069d030ef5039a64925bc023cfd9b Mon Sep 17 00:00:00 2001 From: Sarthak Deshpande Date: Fri, 18 Oct 2024 17:21:33 +0530 Subject: [PATCH] Added implementation for get_agents_step and get_agents_turn --- llama_stack/apis/agents/agents.py | 3 +- .../impls/meta_reference/agents/agents.py | 39 +++++++++++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index de710a94f..7ff276dae 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -443,11 +443,12 @@ class Agents(Protocol): self, agent_id: str, turn_id: str, + session_id: str ) -> Turn: ... @webmethod(route="/agents/step/get") async def get_agents_step( - self, agent_id: str, turn_id: str, step_id: str + self, agent_id: str, turn_id: str, step_id: str, session_id: str ) -> AgentStepResponse: ... @webmethod(route="/agents/session/create") diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index 14db8ab20..21f5ba3af 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -138,13 +138,44 @@ class MetaReferenceAgentsImpl(Agents): async for event in agent.create_and_execute_turn(request): yield event - async def get_agents_turn(self, agent_id: str, turn_id: str) -> Turn: - raise NotImplementedError() + async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: + turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") + try: + turn = json.loads(turn) + except json.JSONDecodeError as e: + raise ValueError( + f"Could not JSON decode turn for {turn_id}" + ) from e + try: + turn = Turn(**turn) + except Exception as e: + raise ValueError( + f"Could not validate(?) Turns for {turn_id}" + ) from e + return turn async def get_agents_step( - self, agent_id: str, turn_id: str, step_id: str + self, agent_id: str, turn_id: str, session_id: str, step_id: str ) -> AgentStepResponse: - raise NotImplementedError() + turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") + try: + turn = json.loads(turn) + except json.JSONDecodeError as e: + raise ValueError( + f"Could not JSON decode turn for {turn_id}" + ) from e + try: + turn = Turn(**turn) + except Exception as e: + raise ValueError( + f"Could not validate(?) Turns for {turn_id}" + ) from e + steps = turn.steps + for step in steps: + if step.step_id == step_id: + return AgentStepResponse(step=step) + raise ValueError("Provided step_id could not be found") + async def get_agents_session( self,