diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index a83538b35..adf4313d7 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -297,6 +297,15 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     tool_config: Optional[ToolConfig] = None
 
 
+@json_schema_type
+class AgentTurnContinueRequest(BaseModel):
+    agent_id: str
+    session_id: str
+    turn_id: str
+    tool_responses: List[ToolResponseMessage]
+    stream: Optional[bool] = False
+
+
 @json_schema_type
 class AgentTurnResponseStreamChunk(BaseModel):
     """streamed agent turn completion response."""
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 8da3f3a14..2ae71ded6 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -23,6 +23,7 @@ from llama_stack.apis.agents import (
     AgentConfig,
     AgentToolGroup,
     AgentToolGroupWithArgs,
+    AgentTurnContinueRequest,
     AgentTurnCreateRequest,
     AgentTurnResponseEvent,
     AgentTurnResponseEventType,
@@ -30,7 +31,6 @@ from llama_stack.apis.agents import (
     AgentTurnResponseStepProgressPayload,
     AgentTurnResponseStepStartPayload,
     AgentTurnResponseStreamChunk,
-    AgentTurnResponseTurnAwaitingInputPayload,
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
     Attachment,
@@ -227,25 +227,64 @@ class ChatAgent(ShieldRunnerMixin):
         )
         await self.storage.add_turn_to_session(request.session_id, turn)
 
-        if output_message.tool_calls:
-            chunk = AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseTurnAwaitingInputPayload(
-                        turn=turn,
-                    )
+        chunk = AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseTurnCompletePayload(
+                    turn=turn,
                 )
             )
-        else:
-            chunk = AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseTurnCompletePayload(
-                        turn=turn,
-                    )
-                )
-            )
-
+        )
         yield chunk
 
+    async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerator:
+        with tracing.span("continue_turn") as span:
+            span.set_attribute("agent_id", self.agent_id)
+            span.set_attribute("session_id", request.session_id)
+            span.set_attribute("turn_id", request.turn_id)
+            span.set_attribute("request", request.model_dump_json())
+            assert request.stream is True, "Non-streaming not supported"
+
+            session_info = await self.storage.get_session_info(request.session_id)
+            if session_info is None:
+                raise ValueError(f"Session {request.session_id} not found")
+
+            turns = await self.storage.get_session_turns(request.session_id)
+
+            messages = []
+            if self.agent_config.instructions != "":
+                messages.append(SystemMessage(content=self.agent_config.instructions))
+
+            for turn in turns:
+                messages.extend(self.turn_to_messages(turn))
+
+            # The continue request carries tool responses, not user messages.
+            messages.extend(request.tool_responses)
+
+            # Sketch: mirror the streaming loop in create_and_execute_turn.
+            steps = []
+            output_message = None
+            async for chunk in self.run(
+                session_id=request.session_id,
+                turn_id=request.turn_id,
+                input_messages=messages,
+                sampling_params=self.agent_config.sampling_params,
+                stream=request.stream,
+            ):
+                if isinstance(chunk, CompletionMessage):
+                    output_message = chunk
+                    continue
+
+                assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
+                event = chunk.event
+                if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
+                    steps.append(event.payload.step_details)
+
+                yield chunk
+
+            assert output_message is not None
+            # Persisting the updated turn and yielding a turn_complete chunk
+            # should follow here, as in create_and_execute_turn above.
+
     async def run(
         self,
         session_id: str,
@@ -626,7 +665,11 @@ class ChatAgent(ShieldRunnerMixin):
                 input_messages = input_messages + [message]
             else:
                 log.info(f"{str(message)}")
-                # 1. Start the tool execution step and progress
+                tool_call = message.tool_calls[0]
+                if tool_call.tool_name in client_tools:
+                    yield message
+                    return
+
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
@@ -636,8 +679,6 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-
-            tool_call = message.tool_calls[0]
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepProgressPayload(
@@ -652,12 +693,6 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-            # If tool is a client tool, yield CompletionMessage and return
-            if tool_call.tool_name in client_tools:
-                yield message
-                return
-
-            # If tool is a builtin server tool, execute it
             tool_name = tool_call.tool_name
             if isinstance(tool_name, BuiltinTool):
                 tool_name = tool_name.value
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index dfbc41262..bdde89016 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -20,6 +20,7 @@ from llama_stack.apis.agents import (
     AgentSessionCreateResponse,
     AgentStepResponse,
     AgentToolGroup,
+    AgentTurnContinueRequest,
     AgentTurnCreateRequest,
     Document,
     Session,
@@ -177,7 +178,24 @@ class MetaReferenceAgentsImpl(Agents):
         tool_responses: List[ToolResponseMessage],
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
-        pass
+        request = AgentTurnContinueRequest(
+            agent_id=agent_id,
+            session_id=session_id,
+            turn_id=turn_id,
+            tool_responses=tool_responses,
+            stream=stream,
+        )
+        if stream:
+            return self._continue_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")
+
+    async def _continue_agent_turn_streaming(
+        self,
+        request: AgentTurnContinueRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
+        async for event in agent.continue_turn(request):
+            yield event
 
     async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
         turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
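
Below is a minimal client-side sketch of the two-phase flow this diff enables: the turn runs until it reaches a client tool call, the client executes the tool locally, then continues the same turn with the results. It is illustrative only: it assumes an initialized MetaReferenceAgentsImpl bound to `impl`, an existing agent and session, and a hypothetical `run_client_tool` helper; the keyword arguments and ToolResponseMessage fields mirror the signatures visible in this diff and in llama_stack.apis.inference, and should be verified against the branch.

from llama_stack.apis.agents import AgentTurnResponseEventType
from llama_stack.apis.inference import ToolResponseMessage, UserMessage


async def run_client_tool(tool_call) -> str:
    # Hypothetical: execute the client-side tool locally.
    return f"result for {tool_call.tool_name}"


async def drive_turn(impl, agent_id: str, session_id: str) -> None:
    # Phase 1: start a turn. With this diff, a turn that stops at a client
    # tool call still ends with a turn_complete payload; the pending call
    # is carried in turn.output_message.tool_calls.
    turn = None
    stream = await impl.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[UserMessage(content="What is the weather in Rome?")],
        stream=True,
    )
    async for chunk in stream:
        payload = chunk.event.payload
        if payload.event_type == AgentTurnResponseEventType.turn_complete.value:
            turn = payload.turn

    if turn is None or not turn.output_message.tool_calls:
        return  # turn finished without needing a client tool

    # Phase 2: answer the pending client tool calls, then continue the turn.
    tool_responses = [
        ToolResponseMessage(
            call_id=tc.call_id,
            tool_name=tc.tool_name,
            content=await run_client_tool(tc),
        )
        for tc in turn.output_message.tool_calls
    ]
    continuation = await impl.continue_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=turn.turn_id,
        tool_responses=tool_responses,
        stream=True,
    )
    async for chunk in continuation:
        payload = chunk.event.payload
        if payload.event_type == AgentTurnResponseEventType.turn_complete.value:
            print(payload.turn.output_message.content)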