diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 0a5d93d80..8e8b8804f 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -2319,7 +2319,7 @@ "post": { "responses": { "200": { - "description": "A single turn in an interaction with an Agentic System. **OR** streamed agent turn completion response.", + "description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.", "content": { "application/json": { "schema": { @@ -2337,11 +2337,12 @@ "tags": [ "Agents" ], - "description": "", + "description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.", "parameters": [ { "name": "agent_id", "in": "path", + "description": "The ID of the agent to resume.", "required": true, "schema": { "type": "string" @@ -2350,6 +2351,7 @@ { "name": "session_id", "in": "path", + "description": "The ID of the session to resume.", "required": true, "schema": { "type": "string" @@ -2358,6 +2360,7 @@ { "name": "turn_id", "in": "path", + "description": "The ID of the turn to resume.", "required": true, "schema": { "type": "string" @@ -4287,6 +4290,9 @@ }, "tool_config": { "$ref": "#/components/schemas/ToolConfig" + }, + "allow_turn_resume": { + "type": "boolean" } }, "additionalProperties": false, @@ -8106,10 +8112,12 @@ "type": "array", "items": { "$ref": "#/components/schemas/ToolResponseMessage" - } + }, + "description": "The tool call responses to resume the turn with." }, "stream": { - "type": "boolean" + "type": "boolean", + "description": "Whether to stream the response." } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index c05eef95e..c14ae3d3a 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -1406,8 +1406,8 @@ paths: responses: '200': description: >- - A single turn in an interaction with an Agentic System. **OR** streamed - agent turn completion response. + A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk + objects. content: application/json: schema: @@ -1417,20 +1417,28 @@ paths: $ref: '#/components/schemas/AgentTurnResponseStreamChunk' tags: - Agents - description: '' + description: >- + Resume an agent turn with executed tool call responses. + + When a Turn has the status `awaiting_input` due to pending input from client + side tool calls, this endpoint can be used to submit the outputs from the + tool calls once they are ready. parameters: - name: agent_id in: path + description: The ID of the agent to resume. required: true schema: type: string - name: session_id in: path + description: The ID of the session to resume. required: true schema: type: string - name: turn_id in: path + description: The ID of the turn to resume. required: true schema: type: string @@ -2779,6 +2787,8 @@ components: $ref: '#/components/schemas/AgentTool' tool_config: $ref: '#/components/schemas/ToolConfig' + allow_turn_resume: + type: boolean additionalProperties: false required: - messages @@ -5242,8 +5252,11 @@ components: type: array items: $ref: '#/components/schemas/ToolResponseMessage' + description: >- + The tool call responses to resume the turn with. stream: type: boolean + description: Whether to stream the response. 
additionalProperties: false required: - tool_responses diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 8fde864e4..c904fdbef 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -296,6 +296,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): stream: Optional[bool] = False tool_config: Optional[ToolConfig] = None + # TODO (xiyan): temporary flag, will remove for 0.1.5 + allow_turn_resume: Optional[bool] = False + @json_schema_type class AgentTurnResumeRequest(BaseModel): @@ -352,6 +355,7 @@ class Agents(Protocol): documents: Optional[List[Document]] = None, toolgroups: Optional[List[AgentToolGroup]] = None, tool_config: Optional[ToolConfig] = None, + allow_turn_resume: Optional[bool] = False, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod( @@ -365,7 +369,19 @@ class Agents(Protocol): turn_id: str, tool_responses: List[ToolResponseMessage], stream: Optional[bool] = False, - ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... + ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: + """Resume an agent turn with executed tool call responses. + + When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready. + + :param agent_id: The ID of the agent to resume. + :param session_id: The ID of the session to resume. + :param turn_id: The ID of the turn to resume. + :param tool_responses: The tool call responses to resume the turn with. + :param stream: Whether to stream the response. + :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects. + """ + ... 
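To make the new contract concrete, here is a minimal, hypothetical sketch of the client-side loop that the resume endpoint enables. It is not part of this diff: `agents` stands for any implementation of the `Agents` protocol, `run_my_tool` is a placeholder for the caller's local tool executor, and it assumes `AgentTurnResponseEventType` gains a `turn_awaiting_input` member to match the new `AgentTurnResponseTurnAwaitingInputPayload`.

```python
# Hypothetical usage sketch (not part of this diff): drive a turn that calls a
# client-side tool, then resume it with the tool's output.
from llama_stack.apis.agents import AgentTurnResponseEventType
from llama_stack.apis.inference import ToolResponseMessage


def run_my_tool(tool_call) -> str:
    # Placeholder for the caller's local tool executor.
    return "tool output"


async def drive_turn(agents, agent_id: str, session_id: str, user_messages) -> None:
    pending_turn = None
    # allow_turn_resume opts in to the new `awaiting_input` terminal state
    # (a temporary flag per the TODO above).
    stream = await agents.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=user_messages,
        stream=True,
        allow_turn_resume=True,
    )
    async for chunk in stream:
        payload = chunk.event.payload
        if payload.event_type == AgentTurnResponseEventType.turn_awaiting_input.value:
            pending_turn = payload.turn

    if pending_turn is None:
        return  # turn completed server-side; nothing to resume

    # Execute the pending client tool calls locally, then submit the outputs.
    tool_responses = [
        ToolResponseMessage(
            call_id=tool_call.call_id,
            tool_name=tool_call.tool_name,
            content=run_my_tool(tool_call),
        )
        for tool_call in pending_turn.output_message.tool_calls
    ]
    # The meta-reference provider currently supports only stream=True on resume.
    async for chunk in await agents.resume_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=pending_turn.turn_id,
        tool_responses=tool_responses,
        stream=True,
    ):
        print(chunk.event.payload.event_type)
```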
@webmethod( route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 1c21df57f..edd253356 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -30,8 +30,10 @@ from llama_stack.apis.agents import ( AgentTurnResponseStepProgressPayload, AgentTurnResponseStepStartPayload, AgentTurnResponseStreamChunk, + AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnStartPayload, + AgentTurnResumeRequest, Attachment, Document, InferenceStep, @@ -62,7 +64,11 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO -from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition +from llama_stack.models.llama.datatypes import ( + BuiltinTool, + ToolCall, + ToolParamDefinition, +) from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing @@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin): async def create_session(self, name: str) -> str: return await self.storage.create_session(name) + async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]: + messages = [] + if self.agent_config.instructions != "": + messages.append(SystemMessage(content=self.agent_config.instructions)) + + for turn in turns: + messages.extend(self.turn_to_messages(turn)) + return messages + async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: with tracing.span("create_and_execute_turn") as span: span.set_attribute("session_id", request.session_id) @@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin): raise ValueError(f"Session {request.session_id} not found") turns = await self.storage.get_session_turns(request.session_id) - - messages = [] - if self.agent_config.instructions != "": - messages.append(SystemMessage(content=self.agent_config.instructions)) - - for i, turn in enumerate(turns): - messages.extend(self.turn_to_messages(turn)) - + messages = await self.get_messages_from_turns(turns) messages.extend(request.messages) turn_id = str(uuid.uuid4()) @@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin): ) await self.storage.add_turn_to_session(request.session_id, turn) - chunk = AgentTurnResponseStreamChunk( + if output_message.tool_calls and request.allow_turn_resume: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + + yield chunk + + async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: + with tracing.span("resume_turn") as span: + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("session_id", request.session_id) + span.set_attribute("turn_id", request.turn_id) + span.set_attribute("request", request.model_dump_json()) + assert request.stream is True, "Non-streaming not supported" + + session_info = await 
self.storage.get_session_info(request.session_id) + if session_info is None: + raise ValueError(f"Session {request.session_id} not found") + + turns = await self.storage.get_session_turns(request.session_id) + messages = await self.get_messages_from_turns(turns) + messages.extend(request.tool_responses) + + last_turn_messages = [ + x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) + ] + + # get the steps from the turn id + steps = [] + if len(turns) > 0: + steps = turns[-1].steps + + # mark tool execution step as complete + # if there's no tool execution in progress step (due to storage, or tool call parsing on client), + # we'll create a new tool execution step with current time + in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( + request.session_id, request.turn_id + ) + now = datetime.now() + tool_execution_step = ToolExecutionStep( + step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), + turn_id=request.turn_id, + tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), + tool_responses=[ + ToolResponse( + call_id=x.call_id, + tool_name=x.tool_name, + content=x.content, + ) + for x in request.tool_responses + ], + completed_at=now, + started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), + ) + steps.append(tool_execution_step) + yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.tool_execution.value, + step_id=tool_execution_step.step_id, + step_details=tool_execution_step, ) ) ) + + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=request.turn_id, + input_messages=messages, + sampling_params=self.agent_config.sampling_params, + stream=request.stream, + ): + if isinstance(chunk, CompletionMessage): + output_message = chunk + continue + + assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" + event = chunk.event + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: + steps.append(event.payload.step_details) + + yield chunk + + assert output_message is not None + + last_turn_start_time = datetime.now() + if len(turns) > 0: + last_turn_start_time = turns[-1].started_at + + turn = Turn( + turn_id=request.turn_id, + session_id=request.session_id, + input_messages=last_turn_messages, + output_message=output_message, + started_at=last_turn_start_time, + completed_at=datetime.now(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + yield chunk async def run( @@ -611,11 +742,7 @@ class ChatAgent(ShieldRunnerMixin): input_messages = input_messages + [message] else: log.info(f"{str(message)}") - tool_call = message.tool_calls[0] - if tool_call.tool_name in client_tools: - yield message - return - + # 1. 
Start the tool execution step and progress step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -625,6 +752,7 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) + tool_call = message.tool_calls[0] yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -639,6 +767,23 @@ class ChatAgent(ShieldRunnerMixin): ) ) + # If tool is a client tool, yield CompletionMessage and return + if tool_call.tool_name in client_tools: + await self.storage.set_in_progress_tool_call_step( + session_id, + turn_id, + ToolExecutionStep( + step_id=step_id, + turn_id=turn_id, + tool_calls=[tool_call], + tool_responses=[], + started_at=datetime.now(), + ), + ) + yield message + return + + # If tool is a builtin server tool, execute it tool_name = tool_call.tool_name if isinstance(tool_name, BuiltinTool): tool_name = tool_name.value diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index acacbdfdf..8a4d91238 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -21,6 +21,7 @@ from llama_stack.apis.agents import ( AgentStepResponse, AgentToolGroup, AgentTurnCreateRequest, + AgentTurnResumeRequest, Document, Session, Turn, @@ -146,6 +147,7 @@ class MetaReferenceAgentsImpl(Agents): documents: Optional[List[Document]] = None, stream: Optional[bool] = False, tool_config: Optional[ToolConfig] = None, + allow_turn_resume: Optional[bool] = False, ) -> AsyncGenerator: request = AgentTurnCreateRequest( agent_id=agent_id, @@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents): toolgroups=toolgroups, documents=documents, tool_config=tool_config, + allow_turn_resume=allow_turn_resume, ) if stream: return self._create_agent_turn_streaming(request) @@ -177,7 +180,25 @@ class MetaReferenceAgentsImpl(Agents): tool_responses: List[ToolResponseMessage], stream: Optional[bool] = False, ) -> AsyncGenerator: - pass + request = AgentTurnResumeRequest( + agent_id=agent_id, + session_id=session_id, + turn_id=turn_id, + tool_responses=tool_responses, + stream=stream, + ) + if stream: + return self._continue_agent_turn_streaming(request) + else: + raise NotImplementedError("Non-streaming agent turns not yet implemented") + + async def _continue_agent_turn_streaming( + self, + request: AgentTurnResumeRequest, + ) -> AsyncGenerator: + agent = await self.get_agent(request.agent_id) + async for event in agent.resume_turn(request): + yield event async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 4b8ad6d4a..3c3866873 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -12,7 +12,7 @@ from typing import List, Optional from pydantic import BaseModel -from llama_stack.apis.agents import Turn +from llama_stack.apis.agents import ToolExecutionStep, Turn from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) @@ -84,3 +84,15 @@ class AgentPersistence: continue turns.sort(key=lambda x: (x.completed_at or datetime.min)) return turns + + async def set_in_progress_tool_call_step(self, 
session_id: str, turn_id: str, step: ToolExecutionStep): + await self.kvstore.set( + key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", + value=step.model_dump_json(), + ) + + async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]: + value = await self.kvstore.get( + key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", + ) + return ToolExecutionStep(**json.loads(value)) if value else None diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index e5380d357..781095d2b 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack_client.types.tool_def_param import Parameter -from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig -from llama_stack.apis.agents.agents import ToolChoice +from llama_stack.apis.agents.agents import ( + AgentConfig as Server__AgentConfig, +) +from llama_stack.apis.agents.agents import ( + ToolChoice, +) class TestClientTool(ClientTool):
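Finally, a small round-trip sketch of the new persistence helpers in this diff. It is illustrative only: it assumes `AgentPersistence` is constructed with an `agent_id` and a `kvstore` (as its attribute access above suggests), and it substitutes a toy dict-backed store for a real `KVStore` implementation.

```python
# Hypothetical round-trip check for the in-progress tool call step helpers.
import asyncio
from datetime import datetime
from typing import Optional
from uuid import uuid4

from llama_stack.apis.agents import ToolExecutionStep
from llama_stack.providers.inline.agents.meta_reference.persistence import (
    AgentPersistence,
)


class DictKVStore:
    """Toy stand-in for a real KVStore implementation."""

    def __init__(self):
        self._data: dict = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def get(self, key: str) -> Optional[str]:
        return self._data.get(key)


async def main() -> None:
    persistence = AgentPersistence(agent_id="agent-123", kvstore=DictKVStore())
    step = ToolExecutionStep(
        step_id=str(uuid4()),
        turn_id="turn-1",
        tool_calls=[],
        tool_responses=[],
        started_at=datetime.now(),
    )
    await persistence.set_in_progress_tool_call_step("sess-1", "turn-1", step)
    restored = await persistence.get_in_progress_tool_call_step("sess-1", "turn-1")
    # The step survives JSON serialization through the kvstore.
    assert restored is not None and restored.step_id == step.step_id


asyncio.run(main())
```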