From e8f7fa1ce164df11bdfb1d2df847d6ad6a8fb283 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 4 Mar 2025 13:38:04 -0800 Subject: [PATCH] wip refactor --- .../agents/meta_reference/agent_instance.py | 147 +++++++++--------- 1 file changed, 75 insertions(+), 72 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 886a36024..e4409d6a7 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -12,7 +12,7 @@ import secrets import string import uuid from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse import httpx @@ -31,7 +31,6 @@ from llama_stack.apis.agents import ( AgentTurnResponseStreamChunk, AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnCompletePayload, - AgentTurnResponseTurnStartPayload, AgentTurnResumeRequest, Attachment, Document, @@ -184,84 +183,88 @@ class ChatAgent(ShieldRunnerMixin): span.set_attribute("session_id", request.session_id) span.set_attribute("agent_id", self.agent_id) span.set_attribute("request", request.model_dump_json()) - assert request.stream is True, "Non-streaming not supported" - - session_info = await self.storage.get_session_info(request.session_id) - if session_info is None: - raise ValueError(f"Session {request.session_id} not found") - - turns = await self.storage.get_session_turns(request.session_id) - messages = await self.get_messages_from_turns(turns) - messages.extend(request.messages) - turn_id = str(uuid.uuid4()) span.set_attribute("turn_id", turn_id) - start_time = datetime.now().astimezone().isoformat() - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnStartPayload( - turn_id=turn_id, - ) - ) - ) - - steps = [] - output_message = None - async for chunk in self.run( - session_id=request.session_id, - turn_id=turn_id, - input_messages=messages, - sampling_params=self.agent_config.sampling_params, - stream=request.stream, - documents=request.documents, - toolgroups_for_turn=request.toolgroups, - ): - if isinstance(chunk, CompletionMessage): - logcat.info( - "agents", - f"returning result from the agent turn: {chunk}", - ) - output_message = chunk - continue - - assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" - event = chunk.event - if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: - steps.append(event.payload.step_details) - + assert request.stream is True, "Non-streaming not supported" + async for chunk in self._run_turn(request, turn_id): yield chunk - assert output_message is not None + async def _run_turn( + self, + request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest], + turn_id: Optional[str] = None, + ) -> AsyncGenerator: + is_resume = isinstance(request, AgentTurnResumeRequest) + session_info = await self.storage.get_session_info(request.session_id) + if session_info is None: + raise ValueError(f"Session {request.session_id} not found") - turn = Turn( - turn_id=turn_id, - session_id=request.session_id, - input_messages=request.messages, - output_message=output_message, - started_at=start_time, - completed_at=datetime.now().astimezone().isoformat(), - steps=steps, - ) - await self.storage.add_turn_to_session(request.session_id, turn) - if output_message.tool_calls: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnAwaitingInputPayload( - turn=turn, - ) - ) - ) - else: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, - ) - ) - ) + turns = await self.storage.get_session_turns(request.session_id) + if is_resume and len(turns) == 0: + raise ValueError("No turns found for session") + + messages = await self.get_messages_from_turns(turns) + if is_resume: + messages.extend(request.tool_responses) + turn_id = request.turn_id + start_time = turns[-1].started_at + else: + messages.extend(request.messages) + start_time = datetime.now().astimezone().isoformat() + + steps = [] + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=turn_id, + input_messages=messages, + sampling_params=self.agent_config.sampling_params, + stream=request.stream, + documents=request.documents if not is_resume else None, + toolgroups_for_turn=request.toolgroups if not is_resume else None, + ): + if isinstance(chunk, CompletionMessage): + output_message = chunk + continue + + assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" + event = chunk.event + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: + steps.append(event.payload.step_details) yield chunk + assert output_message is not None + + turn = Turn( + turn_id=turn_id, + session_id=request.session_id, + input_messages=request.messages, + output_message=output_message, + started_at=start_time, + completed_at=datetime.now().astimezone().isoformat(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + + yield chunk + async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: with tracing.span("resume_turn") as span: span.set_attribute("agent_id", self.agent_id)