feat: ability to retrieve agents session, turn, step by ids (#1286)

# What does this PR do?

- Fix up rotten implementation for retrieving agent's Session, Turn,
Step with actual working implementation.

- Update `getting_started` notebook with retrieving by agent session_id.
https://github.com/meta-llama/llama-stack/blob/export_agent_dataset/docs/getting_started.ipynb

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

Test with script:
https://gist.github.com/yanxi0830/657cecee8f1f0e39d322963d9c0f598e

<img width="503" alt="image"
src="https://github.com/user-attachments/assets/5ea9bc33-83d1-40bc-98e1-b68393158387"
/>


[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-02-27 09:45:14 -08:00 committed by GitHub
parent 0762c61402
commit fc5aff3ccf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 760 additions and 54 deletions

File diff suppressed because one or more lines are too long

View file

@ -61,7 +61,12 @@ from llama_stack.apis.inference import (
UserMessage, UserMessage,
) )
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.tools import (
RAGDocument,
ToolGroups,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,

View file

@ -194,17 +194,13 @@ class MetaReferenceAgentsImpl(Agents):
yield event yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") agent = await self.get_agent(agent_id)
turn = json.loads(turn) turn = await agent.storage.get_session_turn(session_id, turn_id)
turn = Turn(**turn)
return turn return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse: async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") turn = await self.get_agents_turn(agent_id, session_id, turn_id)
turn = json.loads(turn) for step in turn.steps:
turn = Turn(**turn)
steps = turn.steps
for step in steps:
if step.step_id == step_id: if step.step_id == step_id:
return AgentStepResponse(step=step) return AgentStepResponse(step=step)
raise ValueError(f"Provided step_id {step_id} could not be found") raise ValueError(f"Provided step_id {step_id} could not be found")
@ -215,20 +211,18 @@ class MetaReferenceAgentsImpl(Agents):
session_id: str, session_id: str,
turn_ids: Optional[List[str]] = None, turn_ids: Optional[List[str]] = None,
) -> Session: ) -> Session:
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}") agent = await self.get_agent(agent_id)
session = Session(**json.loads(session), turns=[]) session_info = await agent.storage.get_session_info(session_id)
turns = [] if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id)
if turn_ids: if turn_ids:
for turn_id in turn_ids: turns = [turn for turn in turns if turn.turn_id in turn_ids]
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
turns.append(turn)
return Session( return Session(
session_name=session.session_name, session_name=session_info.session_name,
session_id=session_id, session_id=session_id,
turns=turns if turns else [], turns=turns,
started_at=session.started_at, started_at=session_info.started_at,
) )
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: async def delete_agents_session(self, agent_id: str, session_id: str) -> None:

View file

@ -21,6 +21,7 @@ log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel): class AgentSessionInfo(BaseModel):
session_id: str session_id: str
session_name: str session_name: str
# TODO: is this used anywhere?
vector_db_id: Optional[str] = None vector_db_id: Optional[str] = None
started_at: datetime started_at: datetime
@ -85,6 +86,14 @@ class AgentPersistence:
turns.sort(key=lambda x: (x.completed_at or datetime.min)) turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
)
if not value:
return None
return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep): async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
await self.kvstore.set( await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",