mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat: ability to retrieve agents session, turn, step by ids (#1286)
# What does this PR do? - Fix up rotten implementation for retrieving agent's Session, Turn, Step with actual working implementation. - Update `getting_started` notebook with retrieving by agent session_id. https://github.com/meta-llama/llama-stack/blob/export_agent_dataset/docs/getting_started.ipynb [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Test with script: https://gist.github.com/yanxi0830/657cecee8f1f0e39d322963d9c0f598e <img width="503" alt="image" src="https://github.com/user-attachments/assets/5ea9bc33-83d1-40bc-98e1-b68393158387" /> [//]: # (## Documentation)
This commit is contained in:
parent
0762c61402
commit
fc5aff3ccf
4 changed files with 760 additions and 54 deletions
|
@ -194,17 +194,13 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||
turn = json.loads(turn)
|
||||
turn = Turn(**turn)
|
||||
agent = await self.get_agent(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
return turn
|
||||
|
||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||
turn = json.loads(turn)
|
||||
turn = Turn(**turn)
|
||||
steps = turn.steps
|
||||
for step in steps:
|
||||
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
|
||||
for step in turn.steps:
|
||||
if step.step_id == step_id:
|
||||
return AgentStepResponse(step=step)
|
||||
raise ValueError(f"Provided step_id {step_id} could not be found")
|
||||
|
@ -215,20 +211,18 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session:
|
||||
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
|
||||
session = Session(**json.loads(session), turns=[])
|
||||
turns = []
|
||||
agent = await self.get_agent(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
turns = await agent.storage.get_session_turns(session_id)
|
||||
if turn_ids:
|
||||
for turn_id in turn_ids:
|
||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||
turn = json.loads(turn)
|
||||
turn = Turn(**turn)
|
||||
turns.append(turn)
|
||||
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||
return Session(
|
||||
session_name=session.session_name,
|
||||
session_name=session_info.session_name,
|
||||
session_id=session_id,
|
||||
turns=turns if turns else [],
|
||||
started_at=session.started_at,
|
||||
turns=turns,
|
||||
started_at=session_info.started_at,
|
||||
)
|
||||
|
||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue