fix agent run

This commit is contained in:
Xi Yan 2025-03-11 10:59:07 -07:00
parent 83a2c78615
commit c010d25dc0

View file

@ -12,6 +12,7 @@ import uuid
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
Agent,
AgentConfig, AgentConfig,
AgentCreateResponse, AgentCreateResponse,
Agents, Agents,
@ -21,6 +22,8 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest, AgentTurnCreateRequest,
AgentTurnResumeRequest, AgentTurnResumeRequest,
Document, Document,
ListAgentSessionsResponse,
ListAgentsResponse,
Session, Session,
Turn, Turn,
) )
@ -84,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id=agent_id, agent_id=agent_id,
) )
async def get_agent(self, agent_id: str) -> ChatAgent: async def get_chat_agent(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get( agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}", key=f"agent:{agent_id}",
) )
@ -120,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str, agent_id: str,
session_name: str, session_name: str,
) -> AgentSessionCreateResponse: ) -> AgentSessionCreateResponse:
agent = await self.get_agent(agent_id) agent = await self.get_chat_agent(agent_id)
session_id = await agent.create_session(session_name) session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse( return AgentSessionCreateResponse(
@ -160,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
self, self,
request: AgentTurnCreateRequest, request: AgentTurnCreateRequest,
) -> AsyncGenerator: ) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id) agent = await self.get_chat_agent(request.agent_id)
async for event in agent.create_and_execute_turn(request): async for event in agent.create_and_execute_turn(request):
yield event yield event
@ -188,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
self, self,
request: AgentTurnResumeRequest, request: AgentTurnResumeRequest,
) -> AsyncGenerator: ) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id) agent = await self.get_chat_agent(request.agent_id)
async for event in agent.resume_turn(request): async for event in agent.resume_turn(request):
yield event yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self.get_agent(agent_id) agent = await self.get_chat_agent(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id) turn = await agent.storage.get_session_turn(session_id, turn_id)
return turn return turn
@ -210,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id: str, session_id: str,
turn_ids: Optional[List[str]] = None, turn_ids: Optional[List[str]] = None,
) -> Session: ) -> Session:
agent = await self.get_agent(agent_id) agent = await self.get_chat_agent(agent_id)
session_info = await agent.storage.get_session_info(session_id) session_info = await agent.storage.get_session_info(session_id)
if session_info is None: if session_info is None:
raise ValueError(f"Session {session_id} not found") raise ValueError(f"Session {session_id} not found")
@ -232,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_agents(self) -> ListAgentsResponse:
pass
async def get_agent(self, agent_id: str) -> Agent:
pass
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
pass