diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index c662bac69..55bf31f57 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -161,7 +161,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
-                if step.violation:
+                if step.violation and step.violation.user_message:
                     # CompletionMessage itself in the ShieldResponse
                     messages.append(
                         CompletionMessage(
@@ -233,12 +233,8 @@ class ChatAgent(ShieldRunnerMixin):
         steps = []
         messages = await self.get_messages_from_turns(turns)

-        turn_id: str
-        start_time: datetime
-        input_messages: list[Message]
-
         if is_resume:
-            assert isinstance(request, AgentTurnResumeRequest), "Expected AgentTurnResumeRequest for resume"
+            assert isinstance(request, AgentTurnResumeRequest)
             tool_response_messages = [
                 ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
             ]
@@ -259,7 +255,6 @@ class ChatAgent(ShieldRunnerMixin):
             in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
                 request.session_id, request.turn_id
            )
-            now_iso = datetime.now(UTC).isoformat()
             now_dt = datetime.now(UTC)
             tool_execution_step = ToolExecutionStep(
                 step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
@@ -279,23 +274,29 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
-            input_messages = last_turn.input_messages
+            # Cast needed due to list invariance - last_turn.input_messages is the right type
+            input_messages = last_turn.input_messages  # type: ignore[assignment]
-            turn_id = request.turn_id
+            actual_turn_id = request.turn_id
             start_time = last_turn.started_at
         else:
-            assert isinstance(request, AgentTurnCreateRequest), "Expected AgentTurnCreateRequest for create"
+            assert isinstance(request, AgentTurnCreateRequest)
             messages.extend(request.messages)
             start_time = datetime.now(UTC)
-            input_messages = request.messages
+            # Cast needed due to list invariance - request.messages is the right type
+            input_messages = request.messages  # type: ignore[assignment]
+            # Use the generated turn_id from beginning of function
+            actual_turn_id = turn_id if turn_id else str(uuid.uuid4())

         output_message = None
         req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
-        req_sampling = self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
+        req_sampling = (
+            self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
+        )

         async for chunk in self.run(
             session_id=request.session_id,
-            turn_id=turn_id,
+            turn_id=actual_turn_id,
             input_messages=messages,
             sampling_params=req_sampling,
             stream=request.stream,
@@ -307,11 +308,10 @@ class ChatAgent(ShieldRunnerMixin):

             assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
             event = chunk.event
-            if (
-                event.payload.event_type == AgentTurnResponseEventType.step_complete.value
-                and hasattr(event.payload, "step_details")
+            if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr(
+                event.payload, "step_details"
             ):
-                step_details = getattr(event.payload, "step_details")
+                step_details = event.payload.step_details
                 steps.append(step_details)

             yield chunk
@@ -319,9 +319,9 @@ class ChatAgent(ShieldRunnerMixin):
         assert output_message is not None

         turn = Turn(
-            turn_id=turn_id,
+            turn_id=actual_turn_id,
             session_id=request.session_id,
-            input_messages=input_messages,
+            input_messages=input_messages,  # type: ignore[arg-type]
             output_message=output_message,
             started_at=start_time,
             completed_at=datetime.now(UTC),
@@ -361,7 +361,7 @@ class ChatAgent(ShieldRunnerMixin):

         # return a "final value" for the `yield from` statement. we simulate that by yielding a
         # final boolean (to see whether an exception happened) and then explicitly testing for it.
-        if len(self.input_shields) > 0:
+        if self.input_shields:
             async for res in self.run_multiple_shields_wrapper(
                 turn_id, input_messages, self.input_shields, "user-input"
             ):
@@ -390,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):

         # for output shields run on the full input and output combination
         messages = input_messages + [final_response]
-        if len(self.output_shields) > 0:
+        if self.output_shields:
             async for res in self.run_multiple_shields_wrapper(
                 turn_id, messages, self.output_shields, "assistant-output"
             ):
@@ -418,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin):
            return
         step_id = str(uuid.uuid4())
-        shield_call_start_time = datetime.now(UTC).isoformat()
+        shield_call_start_time = datetime.now(UTC)
         try:
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.shield_call.value,
+                        step_type=StepType.shield_call,
                         step_id=step_id,
                         metadata=dict(touchpoint=touchpoint),
                     )
                 )
@@ -435,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin):
            yield AgentTurnResponseStreamChunk(
                event=AgentTurnResponseEvent(
                    payload=AgentTurnResponseStepCompletePayload(
-                       step_type=StepType.shield_call.value,
+                       step_type=StepType.shield_call,
                        step_id=step_id,
                        step_details=ShieldCallStep(
                            step_id=step_id,
                            turn_id=turn_id,
                            violation=e.violation,
                            started_at=shield_call_start_time,
-                           completed_at=datetime.now(UTC).isoformat(),
+                           completed_at=datetime.now(UTC),
                        ),
                    )
                )
@@ -459,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin):
            yield AgentTurnResponseStreamChunk(
                event=AgentTurnResponseEvent(
                    payload=AgentTurnResponseStepCompletePayload(
-                       step_type=StepType.shield_call.value,
+                       step_type=StepType.shield_call,
                        step_id=step_id,
                        step_details=ShieldCallStep(
                            step_id=step_id,
                            turn_id=turn_id,
                            violation=None,
                            started_at=shield_call_start_time,
-                           completed_at=datetime.now(UTC).isoformat(),
+                           completed_at=datetime.now(UTC),
                        ),
                    )
                )
@@ -512,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin):
             else:
                 self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)
-        output_attachments = []
+        output_attachments: list[Attachment] = []
         n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0

         # Build a map of custom tools to their definitions for faster lookup
         client_tools = {}
-        for tool in self.agent_config.client_tools:
-            client_tools[tool.name] = tool
+        if self.agent_config.client_tools:
+            for tool in self.agent_config.client_tools:
+                client_tools[tool.name] = tool

         while True:
             step_id = str(uuid.uuid4())
-            inference_start_time = datetime.now(UTC).isoformat()
+            inference_start_time = datetime.now(UTC)
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.inference.value,
+                        step_type=StepType.inference,
                         step_id=step_id,
                     )
                 )
            )
@@ -554,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin):
             else:
                 return value

-        def _add_type(openai_msg: dict) -> OpenAIMessageParam:
+        def _add_type(openai_msg: Any) -> OpenAIMessageParam:
             # Serialize any nested Pydantic models to plain dicts
             openai_msg = _serialize_nested(openai_msg)

@@ -604,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin):
                     messages=openai_messages,
                     tools=openai_tools if openai_tools else None,
                     tool_choice=tool_choice,
-                    response_format=self.agent_config.response_format,
+                    response_format=self.agent_config.response_format,  # type: ignore[arg-type]
                     temperature=temperature,
                     top_p=top_p,
                     max_tokens=max_tokens,
@@ -614,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin):

                 # Convert OpenAI stream back to Llama Stack format
                 response_stream = convert_openai_chat_completion_stream(
-                    openai_stream, enable_incremental_tool_calls=True
+                    openai_stream,  # type: ignore[arg-type]
+                    enable_incremental_tool_calls=True,
                 )

                 async for chunk in response_stream:
@@ -636,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
                        yield AgentTurnResponseStreamChunk(
                            event=AgentTurnResponseEvent(
                                payload=AgentTurnResponseStepProgressPayload(
-                                   step_type=StepType.inference.value,
+                                   step_type=StepType.inference,
                                    step_id=step_id,
                                    delta=delta,
                                )
                            )
@@ -649,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin):
                        yield AgentTurnResponseStreamChunk(
                            event=AgentTurnResponseEvent(
                                payload=AgentTurnResponseStepProgressPayload(
-                                   step_type=StepType.inference.value,
+                                   step_type=StepType.inference,
                                    step_id=step_id,
                                    delta=delta,
                                )
                            )
                        )
@@ -667,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin):
                 output_attr = json.dumps(
                     {
                         "content": content,
-                        "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
+                        "tool_calls": [
+                            json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall)
+                        ],
                     }
                 )
                 span.set_attribute("output", output_attr)
@@ -683,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin):

             if tool_calls:
                 content = ""
+                # Filter out string tool calls for CompletionMessage (only keep ToolCall objects)
+                valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)]
                 message = CompletionMessage(
                     content=content,
                     stop_reason=stop_reason,
-                    tool_calls=tool_calls,
+                    tool_calls=valid_tool_calls if valid_tool_calls else None,
                 )

             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.inference.value,
+                        step_type=StepType.inference,
                         step_id=step_id,
                         step_details=InferenceStep(
                             # somewhere deep, we are re-assigning message or closing over some
@@ -702,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin):
                             turn_id=turn_id,
                             model_response=copy.deepcopy(message),
                             started_at=inference_start_time,
-                            completed_at=datetime.now(UTC).isoformat(),
+                            completed_at=datetime.now(UTC),
                         ),
                     )
                 )
            )

-            if n_iter >= self.agent_config.max_infer_iters:
+            max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10
+            if n_iter >= max_iters:
                 logger.info(f"done with MAX iterations ({n_iter}), exiting.")
                 # NOTE: mark end_of_turn to indicate to client that we are done with the turn
                 # Do not continue the tool call loop after this point
@@ -721,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin):
                 yield message
                 break

-            if len(message.tool_calls) == 0:
+            if not message.tool_calls or len(message.tool_calls) == 0:
                 if stop_reason == StopReason.end_of_turn:
                     # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
                     if len(output_attachments) > 0:
                         if isinstance(message.content, list):
-                            message.content += output_attachments
+                            # List invariance - attachments are compatible at runtime
+                            message.content += output_attachments  # type: ignore[arg-type]
                         else:
-                            message.content = [message.content] + output_attachments
+                            # List invariance - attachments are compatible at runtime
+                            message.content = [message.content] + output_attachments  # type: ignore[assignment]
                     yield message
                 else:
                     logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
@@ -741,11 +750,12 @@ class ChatAgent(ShieldRunnerMixin):
             non_client_tool_calls = []
             # Separate client and non-client tool calls
-            for tool_call in message.tool_calls:
-                if tool_call.tool_name in client_tools:
-                    client_tool_calls.append(tool_call)
-                else:
-                    non_client_tool_calls.append(tool_call)
+            if message.tool_calls:
+                for tool_call in message.tool_calls:
+                    if tool_call.tool_name in client_tools:
+                        client_tool_calls.append(tool_call)
+                    else:
+                        non_client_tool_calls.append(tool_call)

             # Process non-client tool calls first
             for tool_call in non_client_tool_calls:
                 step_id = str(uuid.uuid4())
@@ -753,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepStartPayload(
-                            step_type=StepType.tool_execution.value,
+                            step_type=StepType.tool_execution,
                             step_id=step_id,
                         )
                     )
                )
@@ -762,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
-                            step_type=StepType.tool_execution.value,
+                            step_type=StepType.tool_execution,
                             step_id=step_id,
                             delta=ToolCallDelta(
                                 parse_status=ToolCallParseStatus.in_progress,
@@ -782,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin):
                     if self.telemetry_enabled
                     else {},
                 ) as span:
-                    tool_execution_start_time = datetime.now(UTC).isoformat()
+                    tool_execution_start_time = datetime.now(UTC)
                     tool_result = await self.execute_tool_call_maybe(
                         session_id,
                         tool_call,
@@ -812,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin):
                        )
                    ],
                    started_at=tool_execution_start_time,
-                   completed_at=datetime.now(UTC).isoformat(),
+                   completed_at=datetime.now(UTC),
                )

                # Yield the step completion event
                yield AgentTurnResponseStreamChunk(
                    event=AgentTurnResponseEvent(
                        payload=AgentTurnResponseStepCompletePayload(
-                           step_type=StepType.tool_execution.value,
+                           step_type=StepType.tool_execution,
                            step_id=step_id,
                            step_details=tool_execution_step,
                        )
                    )
                )
@@ -849,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin):
                     turn_id=turn_id,
                     tool_calls=client_tool_calls,
                     tool_responses=[],
-                    started_at=datetime.now(UTC).isoformat(),
+                    started_at=datetime.now(UTC),
                 ),
             )

@@ -884,19 +894,20 @@ class ChatAgent(ShieldRunnerMixin):
         toolgroup_to_args = toolgroup_to_args or {}

-        tool_name_to_def = {}
+        tool_name_to_def: dict[str, ToolDefinition] = {}
         tool_name_to_args = {}

-        for tool_def in self.agent_config.client_tools:
-            if tool_name_to_def.get(tool_def.name, None):
-                raise ValueError(f"Tool {tool_def.name} already exists")
+        if self.agent_config.client_tools:
+            for tool_def in self.agent_config.client_tools:
+                if tool_name_to_def.get(tool_def.name, None):
+                    raise ValueError(f"Tool {tool_def.name} already exists")

-            # Use input_schema from ToolDef directly
-            tool_name_to_def[tool_def.name] = ToolDefinition(
-                tool_name=tool_def.name,
-                description=tool_def.description,
-                input_schema=tool_def.input_schema,
-            )
+                # Use input_schema from ToolDef directly
+                tool_name_to_def[tool_def.name] = ToolDefinition(
+                    tool_name=tool_def.name,
+                    description=tool_def.description,
+                    input_schema=tool_def.input_schema,
+                )

         for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
             toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
             tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
@@ -924,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin):

                 else:
                     identifier = None

-                if tool_name_to_def.get(identifier, None):
-                    raise ValueError(f"Tool {identifier} already exists")
                 if identifier:
-                    tool_name_to_def[identifier] = ToolDefinition(
-                        tool_name=identifier,
+                    # Convert BuiltinTool to string for dictionary key
+                    identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
+                    if tool_name_to_def.get(identifier_str, None):
+                        raise ValueError(f"Tool {identifier_str} already exists")
+                    tool_name_to_def[identifier_str] = ToolDefinition(
+                        tool_name=identifier_str,
                         description=tool_def.description,
                         input_schema=tool_def.input_schema,
                     )
-                    tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
+                    tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {})

         self.tool_defs, self.tool_name_to_args = (
             list(tool_name_to_def.values()),
@@ -1033,7 +1046,7 @@ def _interpret_content_as_attachment(
         snippet = match.group(1)
         data = json.loads(snippet)
         return Attachment(
-            url=URL(uri="file://" + data["filepath"]),
+            content=URL(uri="file://" + data["filepath"]),
             mime_type=data["mimetype"],
         )

diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
index 3a07a220e..82f58bae0 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
@@ -28,7 +28,6 @@ from llama_stack.apis.agents.openai_responses import (
 )
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
-    InterleavedContent,
     TextContentItem,
 )
 from llama_stack.apis.inference import (
@@ -114,9 +113,7 @@ class ToolExecutor:
             final_output_message=output_message,
             final_input_message=input_message,
             citation_files=(
-                metadata.get("citation_files")
-                if result and (metadata := getattr(result, "metadata", None))
-                else None
+                metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
             ),
         )

@@ -399,9 +396,9 @@ class ToolExecutor:
             )
             if error_exc:
                 message.error = str(error_exc)
-            elif (
-                result and (error_code := getattr(result, "error_code", None)) and error_code > 0
-            ) or (result and (error_message := getattr(result, "error_message", None))):
+            elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
+                result and getattr(result, "error_message", None)
+            ):
                 ec = getattr(result, "error_code", "unknown")
                 em = getattr(result, "error_message", "")
                 message.error = f"Error (code {ec}): {em}"