mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Added implementations for get_agents_session, delete_agents_session and delete_agents (#267)
This commit is contained in:
parent
b81a3bd46a
commit
8a01b9e40c
2 changed files with 41 additions and 11 deletions
|
@ -438,14 +438,12 @@ class Agents(Protocol):
|
|||
|
||||
@webmethod(route="/agents/turn/get")
|
||||
async def get_agents_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
turn_id: str,
|
||||
self, agent_id: str, session_id: str, turn_id: str
|
||||
) -> Turn: ...
|
||||
|
||||
@webmethod(route="/agents/step/get")
|
||||
async def get_agents_step(
|
||||
self, agent_id: str, turn_id: str, step_id: str
|
||||
self, agent_id: str, session_id: str, turn_id: str, step_id: str
|
||||
) -> AgentStepResponse: ...
|
||||
|
||||
@webmethod(route="/agents/session/create")
|
||||
|
|
|
@ -138,13 +138,29 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, turn_id: str) -> Turn:
|
||||
raise NotImplementedError()
|
||||
async def get_agents_turn(
|
||||
self, agent_id: str, session_id: str, turn_id: str
|
||||
) -> Turn:
|
||||
turn = await self.persistence_store.get(
|
||||
f"session:{agent_id}:{session_id}:{turn_id}"
|
||||
)
|
||||
turn = json.loads(turn)
|
||||
turn = Turn(**turn)
|
||||
return turn
|
||||
|
||||
async def get_agents_step(
|
||||
self, agent_id: str, turn_id: str, step_id: str
|
||||
self, agent_id: str, session_id: str, turn_id: str, step_id: str
|
||||
) -> AgentStepResponse:
|
||||
raise NotImplementedError()
|
||||
turn = await self.persistence_store.get(
|
||||
f"session:{agent_id}:{session_id}:{turn_id}"
|
||||
)
|
||||
turn = json.loads(turn)
|
||||
turn = Turn(**turn)
|
||||
steps = turn.steps
|
||||
for step in steps:
|
||||
if step.step_id == step_id:
|
||||
return AgentStepResponse(step=step)
|
||||
raise ValueError(f"Provided step_id {step_id} could not be found")
|
||||
|
||||
async def get_agents_session(
|
||||
self,
|
||||
|
@ -152,10 +168,26 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session:
|
||||
raise NotImplementedError()
|
||||
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
|
||||
session = Session(**json.loads(session))
|
||||
turns = []
|
||||
if turn_ids:
|
||||
for turn_id in turn_ids:
|
||||
turn = await self.persistence_store.get(
|
||||
f"session:{agent_id}:{session_id}:{turn_id}"
|
||||
)
|
||||
turn = json.loads(turn)
|
||||
turn = Turn(**turn)
|
||||
turns.append(turn)
|
||||
return Session(
|
||||
session_name=session.session_name,
|
||||
session_id=session_id,
|
||||
turns=turns if turns else [],
|
||||
started_at=session.started_at,
|
||||
)
|
||||
|
||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||
raise NotImplementedError()
|
||||
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
|
||||
|
||||
async def delete_agents(self, agent_id: str) -> None:
|
||||
raise NotImplementedError()
|
||||
await self.persistence_store.delete(f"agent:{agent_id}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue