Xi Yan 2025-02-20 17:38:21 -08:00
parent 01f90dfe0c
commit cd36a77e20
3 changed files with 69 additions and 26 deletions

View file

@@ -297,6 +297,15 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     tool_config: Optional[ToolConfig] = None


 @json_schema_type
+class AgentTurnContinueRequest(BaseModel):
+    agent_id: str
+    session_id: str
+    turn_id: str
+    tool_responses: List[ToolResponseMessage]
+    stream: Optional[bool] = False
+
+
+@json_schema_type
 class AgentTurnResponseStreamChunk(BaseModel):
     """streamed agent turn completion response."""

View file

@@ -23,6 +23,7 @@ from llama_stack.apis.agents import (
     AgentConfig,
     AgentToolGroup,
     AgentToolGroupWithArgs,
+    AgentTurnContinueRequest,
     AgentTurnCreateRequest,
     AgentTurnResponseEvent,
     AgentTurnResponseEventType,
@@ -30,7 +31,6 @@ from llama_stack.apis.agents import (
     AgentTurnResponseStepProgressPayload,
     AgentTurnResponseStepStartPayload,
     AgentTurnResponseStreamChunk,
-    AgentTurnResponseTurnAwaitingInputPayload,
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
     Attachment,
@@ -227,25 +227,51 @@ class ChatAgent(ShieldRunnerMixin):
             )
             await self.storage.add_turn_to_session(request.session_id, turn)
             if output_message.tool_calls:
-                chunk = AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseTurnAwaitingInputPayload(
-                            turn=turn,
-                        )
-                    )
-                )
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnCompletePayload(
+                            turn=turn,
+                        )
+                    )
+                )
             else:
                 chunk = AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseTurnCompletePayload(
                             turn=turn,
                         )
                     )
                 )
             yield chunk
+    async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerator:
+        with tracing.span("continue_turn") as span:
+            span.set_attribute("agent_id", self.agent_id)
+            span.set_attribute("session_id", request.session_id)
+            span.set_attribute("turn_id", request.turn_id)
+            span.set_attribute("request", request.model_dump_json())
+            assert request.stream is True, "Non-streaming not supported"
+
+            session_info = await self.storage.get_session_info(request.session_id)
+            if session_info is None:
+                raise ValueError(f"Session {request.session_id} not found")
+
+            turns = await self.storage.get_session_turns(request.session_id)
+            messages = []
+            if self.agent_config.instructions != "":
+                messages.append(SystemMessage(content=self.agent_config.instructions))
+            for i, turn in enumerate(turns):
+                messages.extend(self.turn_to_messages(turn))
+            messages.extend(request.tool_responses)
+            # steps = []
+            # output_message = None
+            # async for chunk in self.run(
+            #     session_id=request.session_id,
+            #     turn_id=request.turn_id,
+            #     input_messages=messages,
+            #     sampling_params=self.agent_config.sampling_params,
+            #     stream=request.stream,
+            #     documents=request.documents,
+            #     toolgroups_for_turn=request.toolgroups,
+            # ):
+            #     if isinstance(chunk, CompletionMessage):
+
     async def run(
         self,
         session_id: str,
@@ -626,7 +652,11 @@ class ChatAgent(ShieldRunnerMixin):
                 input_messages = input_messages + [message]
             else:
                 log.info(f"{str(message)}")
                 # 1. Start the tool execution step and progress
+                tool_call = message.tool_calls[0]
+                if tool_call.tool_name in client_tools:
+                    yield message
+                    return
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
@@ -636,8 +666,6 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
-
-                tool_call = message.tool_calls[0]
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
@@ -652,12 +680,6 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
-
-                # If tool is a client tool, yield CompletionMessage and return
-                if tool_call.tool_name in client_tools:
-                    yield message
-                    return
-                # If tool is a builtin server tool, execute it
                 tool_name = tool_call.tool_name
                 if isinstance(tool_name, BuiltinTool):
                     tool_name = tool_name.value
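
Net effect of the changes above: when the model requests a client-side tool, run() now yields the CompletionMessage and returns before the tool-execution step, so the turn suspends until the client reports results through continue_turn. A toy, self-contained sketch of that pause/resume shape, using stand-in strings rather than the llama_stack event types:

import asyncio
from typing import AsyncIterator, Optional

async def turn(tool_result: Optional[str] = None) -> AsyncIterator[str]:
    if tool_result is None:
        # First half of the turn: surface the client tool call and stop,
        # mirroring `yield message; return` in ChatAgent.run above.
        yield "inference: model requested client tool get_weather"
        yield "turn suspended: awaiting client tool response"
        return
    # Second half: the client's tool output lets the turn finish, which is
    # what continue_turn is intended to do once it is wired up.
    yield f"tool_execution: client supplied {tool_result!r}"
    yield "turn complete"

async def main() -> None:
    async for event in turn():  # runs until the client tool call surfaces
        print(event)
    result = '{"temp_c": 21}'  # the client executes the tool locally
    async for event in turn(tool_result=result):  # resume the same turn
        print(event)

asyncio.run(main())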

View file

@@ -20,6 +20,7 @@ from llama_stack.apis.agents import (
     AgentSessionCreateResponse,
     AgentStepResponse,
     AgentToolGroup,
+    AgentTurnContinueRequest,
     AgentTurnCreateRequest,
     Document,
     Session,
@@ -177,7 +178,18 @@ class MetaReferenceAgentsImpl(Agents):
         tool_responses: List[ToolResponseMessage],
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
-        pass
+        if stream:
+            return self._continue_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")
+
+    async def _continue_agent_turn_streaming(
+        self,
+        request: AgentTurnContinueRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
+        async for event in agent.continue_turn(request):
+            yield event

     async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
         turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
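
One caveat in the provider hunk above: as committed, the streaming branch returns self._continue_agent_turn_streaming(request), but no request is constructed in the visible body, so this would raise a NameError. Presumably an AgentTurnContinueRequest is meant to be assembled from the method's parameters first; a sketch of the missing piece, with parameter names inferred from the visible signature tail:

        request = AgentTurnContinueRequest(
            agent_id=agent_id,
            session_id=session_id,
            turn_id=turn_id,
            tool_responses=tool_responses,
            stream=stream,
        )
        if stream:
            return self._continue_agent_turn_streaming(request)
        else:
            raise NotImplementedError("Non-streaming agent turns not yet implemented")

A caller would then await the method to obtain the generator and iterate it (stream = await impl.continue_agent_turn(...); async for chunk in stream: ...), mirroring how the existing create-turn streaming path is consumed.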