diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 80ef068c7..c662bac69 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -11,6 +11,7 @@ import uuid import warnings from collections.abc import AsyncGenerator from datetime import UTC, datetime +from typing import Any import httpx @@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin): ) def turn_to_messages(self, turn: Turn) -> list[Message]: - messages = [] + messages: list[Message] = [] # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages tool_call_ids = set() for step in turn.steps: - if step.step_type == StepType.tool_execution.value: + if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep): for response in step.tool_responses: tool_call_ids.add(response.call_id) @@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin): messages.append(msg) for step in turn.steps: - if step.step_type == StepType.inference.value: + if step.step_type == StepType.inference.value and isinstance(step, InferenceStep): messages.append(step.model_response) - elif step.step_type == StepType.tool_execution.value: + elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep): for response in step.tool_responses: messages.append( ToolResponseMessage( @@ -159,7 +160,7 @@ class ChatAgent(ShieldRunnerMixin): content=response.content, ) ) - elif step.step_type == StepType.shield_call.value: + elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep): if step.violation: # CompletionMessage itself in the ShieldResponse messages.append( @@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin): return await self.storage.create_session(name) async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]: - messages = [] + messages: list[Message] = [] if self.agent_config.instructions != "": messages.append(SystemMessage(content=self.agent_config.instructions)) @@ -231,7 +232,13 @@ class ChatAgent(ShieldRunnerMixin): steps = [] messages = await self.get_messages_from_turns(turns) + + turn_id: str + start_time: datetime + input_messages: list[Message] + if is_resume: + assert isinstance(request, AgentTurnResumeRequest), "Expected AgentTurnResumeRequest for resume" tool_response_messages = [ ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses ] @@ -252,20 +259,21 @@ class ChatAgent(ShieldRunnerMixin): in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( request.session_id, request.turn_id ) - now = datetime.now(UTC).isoformat() + now_iso = datetime.now(UTC).isoformat() + now_dt = datetime.now(UTC) tool_execution_step = ToolExecutionStep( step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), turn_id=request.turn_id, tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), tool_responses=request.tool_responses, - completed_at=now, - started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), + completed_at=now_dt, + started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt), ) steps.append(tool_execution_step) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.tool_execution.value, + step_type=StepType.tool_execution, step_id=tool_execution_step.step_id, step_details=tool_execution_step, ) @@ -276,18 +284,22 @@ class ChatAgent(ShieldRunnerMixin): turn_id = request.turn_id start_time = last_turn.started_at else: + assert isinstance(request, AgentTurnCreateRequest), "Expected AgentTurnCreateRequest for create" messages.extend(request.messages) - start_time = datetime.now(UTC).isoformat() + start_time = datetime.now(UTC) input_messages = request.messages output_message = None + req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None + req_sampling = self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams() + async for chunk in self.run( session_id=request.session_id, turn_id=turn_id, input_messages=messages, - sampling_params=self.agent_config.sampling_params, + sampling_params=req_sampling, stream=request.stream, - documents=request.documents if not is_resume else None, + documents=req_documents, ): if isinstance(chunk, CompletionMessage): output_message = chunk @@ -295,8 +307,12 @@ class ChatAgent(ShieldRunnerMixin): assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" event = chunk.event - if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: - steps.append(event.payload.step_details) + if ( + event.payload.event_type == AgentTurnResponseEventType.step_complete.value + and hasattr(event.payload, "step_details") + ): + step_details = getattr(event.payload, "step_details") + steps.append(step_details) yield chunk @@ -308,7 +324,7 @@ class ChatAgent(ShieldRunnerMixin): input_messages=input_messages, output_message=output_message, started_at=start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), steps=steps, ) await self.storage.add_turn_to_session(request.session_id, turn)