From 1d52c303d156d6ee03c4a4fa8dd94f96ad637383 Mon Sep 17 00:00:00 2001
From: Mustafa Elbehery
Date: Tue, 8 Jul 2025 21:11:45 +0200
Subject: [PATCH] chore(api): add mypy coverage to meta_reference_agent_instance

Signed-off-by: Mustafa Elbehery
---
 .../agents/meta_reference/agent_instance.py   | 181 ++++++++++--------
 .../inline/agents/meta_reference/safety.py    |   4 +-
 pyproject.toml                                |   1 -
 3 files changed, 105 insertions(+), 81 deletions(-)

diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 4d2b9f8bf..b8e377c27 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -12,6 +12,7 @@
 import string
 import uuid
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime
+from typing import Any
 
 import httpx
@@ -73,7 +74,7 @@
 from .persistence import AgentPersistence
 from .safety import SafetyException, ShieldRunnerMixin
 
 
-def make_random_string(length: int = 8):
+def make_random_string(length: int = 8) -> str:
     return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
 
@@ -117,14 +118,15 @@ class ChatAgent(ShieldRunnerMixin):
         )
 
     def turn_to_messages(self, turn: Turn) -> list[Message]:
-        messages = []
+        messages: list[Message] = []
 
         # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
         tool_call_ids = set()
         for step in turn.steps:
             if step.step_type == StepType.tool_execution.value:
-                for response in step.tool_responses:
-                    tool_call_ids.add(response.call_id)
+                if isinstance(step, ToolExecutionStep):
+                    for response in step.tool_responses:
+                        tool_call_ids.add(response.call_id)
 
         for m in turn.input_messages:
             msg = m.model_copy()
@@ -142,31 +144,34 @@
 
         for step in turn.steps:
             if step.step_type == StepType.inference.value:
-                messages.append(step.model_response)
+                if isinstance(step, InferenceStep):
+                    messages.append(step.model_response)
             elif step.step_type == StepType.tool_execution.value:
-                for response in step.tool_responses:
-                    messages.append(
-                        ToolResponseMessage(
-                            call_id=response.call_id,
-                            content=response.content,
+                if isinstance(step, ToolExecutionStep):
+                    for response in step.tool_responses:
+                        messages.append(
+                            ToolResponseMessage(
+                                call_id=response.call_id,
+                                content=response.content,
+                            )
                         )
-                    )
             elif step.step_type == StepType.shield_call.value:
-                if step.violation:
-                    # CompletionMessage itself in the ShieldResponse
-                    messages.append(
-                        CompletionMessage(
-                            content=step.violation.user_message,
-                            stop_reason=StopReason.end_of_turn,
+                if isinstance(step, ShieldCallStep):
+                    if step.violation:
+                        # CompletionMessage itself in the ShieldResponse
+                        messages.append(
+                            CompletionMessage(
+                                content=step.violation.user_message or "",
+                                stop_reason=StopReason.end_of_turn,
+                            )
                         )
-                    )
 
         return messages
 
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)
 
     async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
-        messages = []
+        messages: list[Message] = []
         if self.agent_config.instructions != "":
             messages.append(SystemMessage(content=self.agent_config.instructions))
@@ -174,7 +179,9 @@
             messages.extend(self.turn_to_messages(turn))
         return messages
 
-    async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
+    async def create_and_execute_turn(
+        self, request: AgentTurnCreateRequest
+    ) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
         span = tracing.get_current_span()
         if span:
             span.set_attribute("session_id", request.session_id)
@@ -189,7 +196,7 @@
         async for chunk in self._run_turn(request, turn_id):
             yield chunk
 
-    async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
+    async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
         span = tracing.get_current_span()
         if span:
             span.set_attribute("agent_id", self.agent_id)
@@ -207,7 +214,7 @@
         self,
         request: AgentTurnCreateRequest | AgentTurnResumeRequest,
         turn_id: str | None = None,
-    ) -> AsyncGenerator:
+    ) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
         assert request.stream is True, "Non-streaming not supported"
 
         is_resume = isinstance(request, AgentTurnResumeRequest)
@@ -271,11 +278,13 @@
             input_messages = request.messages
 
         output_message = None
+        turn_id_final = turn_id or str(uuid.uuid4())
+        sampling_params = self.agent_config.sampling_params or SamplingParams()
         async for chunk in self.run(
             session_id=request.session_id,
-            turn_id=turn_id,
+            turn_id=turn_id_final,
             input_messages=messages,
-            sampling_params=self.agent_config.sampling_params,
+            sampling_params=sampling_params,
             stream=request.stream,
             documents=request.documents if not is_resume else None,
         ):
@@ -286,19 +295,20 @@
             assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
             event = chunk.event
             if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
-                steps.append(event.payload.step_details)
+                if hasattr(event.payload, "step_details"):
+                    steps.append(event.payload.step_details)
 
             yield chunk
 
         assert output_message is not None
 
         turn = Turn(
-            turn_id=turn_id,
+            turn_id=turn_id_final,
             session_id=request.session_id,
             input_messages=input_messages,
             output_message=output_message,
             started_at=start_time,
-            completed_at=datetime.now(UTC).isoformat(),
+            completed_at=datetime.now(UTC),
             steps=steps,
         )
         await self.storage.add_turn_to_session(request.session_id, turn)
@@ -329,7 +339,7 @@
         sampling_params: SamplingParams,
         stream: bool = False,
         documents: list[Document] | None = None,
-    ) -> AsyncGenerator:
+    ) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage, None]:
         # Doing async generators makes downstream code much simpler and everything amenable to
         # streaming. However, it also makes things complicated here because AsyncGenerators cannot
         # return a "final value" for the `yield from` statement. we simulate that by yielding a
@@ -381,7 +391,7 @@
         messages: list[Message],
         shields: list[str],
         touchpoint: str,
-    ) -> AsyncGenerator:
+    ) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage | bool, None]:
         async with tracing.span("run_shields") as span:
             span.set_attribute("input", [m.model_dump_json() for m in messages])
             if len(shields) == 0:
@@ -389,12 +399,12 @@
                 return
             step_id = str(uuid.uuid4())
-            shield_call_start_time = datetime.now(UTC).isoformat()
+            shield_call_start_time = datetime.now(UTC)
             try:
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepStartPayload(
-                            step_type=StepType.shield_call.value,
+                            step_type=StepType.shield_call,
                             step_id=step_id,
                             metadata=dict(touchpoint=touchpoint),
                         )
                     )
@@ -406,14 +416,14 @@
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepCompletePayload(
-                            step_type=StepType.shield_call.value,
+                            step_type=StepType.shield_call,
                             step_id=step_id,
                             step_details=ShieldCallStep(
                                 step_id=step_id,
                                 turn_id=turn_id,
                                 violation=e.violation,
                                 started_at=shield_call_start_time,
-                                completed_at=datetime.now(UTC).isoformat(),
+                                completed_at=datetime.now(UTC),
                             ),
                         )
                     )
@@ -429,14 +439,14 @@
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.shield_call.value,
+                        step_type=StepType.shield_call,
                         step_id=step_id,
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
                             violation=None,
                             started_at=shield_call_start_time,
-                            completed_at=datetime.now(UTC).isoformat(),
+                            completed_at=datetime.now(UTC),
                         ),
                     )
                 )
@@ -451,7 +461,7 @@
         sampling_params: SamplingParams,
         stream: bool = False,
         documents: list[Document] | None = None,
-    ) -> AsyncGenerator:
+    ) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage | bool, None]:
         # if document is passed in a turn, we parse the raw text of the document
         # and sent it as a user message
         if documents:
@@ -481,43 +491,46 @@
             else:
                 self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
 
-        output_attachments = []
+        output_attachments: list[Attachment] = []
 
         n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
 
         # Build a map of custom tools to their definitions for faster lookup
-        client_tools = {}
-        for tool in self.agent_config.client_tools:
+        client_tools: dict[str, Any] = {}
+        for tool in self.agent_config.client_tools or []:
             client_tools[tool.name] = tool
         while True:
             step_id = str(uuid.uuid4())
-            inference_start_time = datetime.now(UTC).isoformat()
+            inference_start_time = datetime.now(UTC)
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.inference.value,
+                        step_type=StepType.inference,
                         step_id=step_id,
                     )
                 )
             )
 
-            tool_calls = []
+            tool_calls: list[ToolCall] = []
             content = ""
             stop_reason = None
 
             async with tracing.span("inference") as span:
                 if self.agent_config.name:
                     span.set_attribute("agent_name", self.agent_config.name)
-                async for chunk in await self.inference_api.chat_completion(
+                chat_completion_response = await self.inference_api.chat_completion(
                     self.agent_config.model,
                     input_messages,
                     tools=self.tool_defs,
-                    tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
+                    tool_prompt_format=self.agent_config.tool_config.tool_prompt_format
+                    if self.agent_config.tool_config
+                    else None,
                     response_format=self.agent_config.response_format,
                     stream=True,
                     sampling_params=sampling_params,
                     tool_config=self.agent_config.tool_config,
-                ):
+                )
+                async for chunk in chat_completion_response:
                     event = chunk.event
                     if event.event_type == ChatCompletionResponseEventType.start:
                         continue
@@ -527,16 +540,18 @@
                     delta = event.delta
                     if delta.type == "tool_call":
-                        if delta.parse_status == ToolCallParseStatus.succeeded:
-                            tool_calls.append(delta.tool_call)
-                        elif delta.parse_status == ToolCallParseStatus.failed:
+                        if hasattr(delta, "parse_status") and delta.parse_status == ToolCallParseStatus.succeeded:
+                            if hasattr(delta, "tool_call"):
+                                tool_calls.append(delta.tool_call)
+                        elif hasattr(delta, "parse_status") and delta.parse_status == ToolCallParseStatus.failed:
                             # If we cannot parse the tools, set the content to the unparsed raw text
-                            content = delta.tool_call
+                            if hasattr(delta, "tool_call"):
+                                content = str(delta.tool_call)
                         if stream:
                             yield AgentTurnResponseStreamChunk(
                                 event=AgentTurnResponseEvent(
                                     payload=AgentTurnResponseStepProgressPayload(
-                                        step_type=StepType.inference.value,
+                                        step_type=StepType.inference,
                                         step_id=step_id,
                                         delta=delta,
                                     )
                                 )
@@ -544,12 +559,13 @@
 
                     elif delta.type == "text":
-                        content += delta.text
+                        if hasattr(delta, "text"):
+                            content += delta.text
                         if stream and event.stop_reason is None:
                             yield AgentTurnResponseStreamChunk(
                                 event=AgentTurnResponseEvent(
                                     payload=AgentTurnResponseStepProgressPayload(
-                                        step_type=StepType.inference.value,
+                                        step_type=StepType.inference,
                                         step_id=step_id,
                                         delta=delta,
                                     )
                                 )
@@ -565,10 +581,16 @@
                     "input",
                     json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
                 )
+                tool_calls_json = []
+                for t in tool_calls:
+                    if hasattr(t, "model_dump_json"):
+                        tool_calls_json.append(json.loads(t.model_dump_json()))
+                    else:
+                        tool_calls_json.append(str(t))
                 output_attr = json.dumps(
                     {
                         "content": content,
-                        "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
+                        "tool_calls": tool_calls_json,
                     }
                 )
                 span.set_attribute("output", output_attr)
@@ -593,7 +615,7 @@
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.inference.value,
+                        step_type=StepType.inference,
                         step_id=step_id,
                         step_details=InferenceStep(
                             # somewhere deep, we are re-assigning message or closing over some
@@ -603,13 +625,14 @@
                             turn_id=turn_id,
                             model_response=copy.deepcopy(message),
                             started_at=inference_start_time,
-                            completed_at=datetime.now(UTC).isoformat(),
+                            completed_at=datetime.now(UTC),
                         ),
                     )
                 )
             )
 
-            if n_iter >= self.agent_config.max_infer_iters:
+            max_infer_iters = self.agent_config.max_infer_iters or 10
+            if n_iter >= max_infer_iters:
                 logger.info(f"done with MAX iterations ({n_iter}), exiting.")
                 # NOTE: mark end_of_turn to indicate to client that we are done with the turn
                 # Do not continue the tool call loop after this point
@@ -622,7 +645,8 @@
                 yield message
                 break
 
-            if len(message.tool_calls) == 0:
+            tool_calls_to_process = message.tool_calls or []
+            if len(tool_calls_to_process) == 0:
                 if stop_reason == StopReason.end_of_turn:
                     # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
                     if len(output_attachments) > 0:
@@ -642,7 +666,7 @@ class ChatAgent(ShieldRunnerMixin):
             non_client_tool_calls = []
 
             # Separate client and non-client tool calls
-            for tool_call in message.tool_calls:
+            for tool_call in tool_calls_to_process:
                 if tool_call.tool_name in client_tools:
                     client_tool_calls.append(tool_call)
                 else:
@@ -654,7 +678,7 @@
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepStartPayload(
-                            step_type=StepType.tool_execution.value,
+                            step_type=StepType.tool_execution,
                             step_id=step_id,
                         )
                     )
@@ -663,7 +687,7 @@
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
-                            step_type=StepType.tool_execution.value,
+                            step_type=StepType.tool_execution,
                             step_id=step_id,
                             delta=ToolCallDelta(
                                 parse_status=ToolCallParseStatus.in_progress,
@@ -681,7 +705,7 @@
                         "input": message.model_dump_json(),
                     },
                 ) as span:
-                    tool_execution_start_time = datetime.now(UTC).isoformat()
+                    tool_execution_start_time = datetime.now(UTC)
                     tool_result = await self.execute_tool_call_maybe(
                         session_id,
                         tool_call,
@@ -710,14 +734,14 @@
                             )
                         ],
                         started_at=tool_execution_start_time,
-                        completed_at=datetime.now(UTC).isoformat(),
+                        completed_at=datetime.now(UTC),
                     )
                     # Yield the step completion event
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
                             payload=AgentTurnResponseStepCompletePayload(
-                                step_type=StepType.tool_execution.value,
+                                step_type=StepType.tool_execution,
                                 step_id=step_id,
                                 step_details=tool_execution_step,
                             )
                         )
@@ -747,7 +771,7 @@
                     turn_id=turn_id,
                     tool_calls=client_tool_calls,
                     tool_responses=[],
-                    started_at=datetime.now(UTC).isoformat(),
+                    started_at=datetime.now(UTC),
                 ),
             )
 
@@ -766,7 +790,7 @@
         self,
         toolgroups_for_turn: list[AgentToolGroup] | None = None,
     ) -> None:
-        toolgroup_to_args = {}
+        toolgroup_to_args: dict[str, Any] = {}
         for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
             if isinstance(toolgroup, AgentToolGroupWithArgs):
                 tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
@@ -782,10 +806,10 @@
 
         toolgroup_to_args = toolgroup_to_args or {}
 
-        tool_name_to_def = {}
-        tool_name_to_args = {}
+        tool_name_to_def: dict[str, ToolDefinition] = {}
+        tool_name_to_args: dict[str, Any] = {}
 
-        for tool_def in self.agent_config.client_tools:
+        for tool_def in self.agent_config.client_tools or []:
             if tool_name_to_def.get(tool_def.name, None):
                 raise ValueError(f"Tool {tool_def.name} already exists")
             tool_name_to_def[tool_def.name] = ToolDefinition(
@@ -798,7 +822,7 @@
                         required=param.required,
                         default=param.default,
                     )
-                    for param in tool_def.parameters
+                    for param in tool_def.parameters or []
                 },
             )
         for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
@@ -828,7 +852,7 @@
             else:
                 identifier = None
 
-            if tool_name_to_def.get(identifier, None):
+            if identifier is not None and tool_name_to_def.get(str(identifier), None):
                 raise ValueError(f"Tool {identifier} already exists")
             if identifier:
                 tool_name_to_def[tool_def.identifier] = ToolDefinition(
@@ -841,7 +865,7 @@
                             required=param.required,
                             default=param.default,
                         )
-                        for param in tool_def.parameters
+                        for param in tool_def.parameters or []
                     },
                 )
                 tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
@@ -888,14 +912,15 @@
             tool_name_str = tool_name
 
         logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
+        kwargs_dict = {"session_id": session_id}
+        if tool_call.arguments and isinstance(tool_call.arguments, dict):
+            kwargs_dict.update(tool_call.arguments)
+        tool_args = self.tool_name_to_args.get(tool_name_str, {})
+        if tool_args:
+            kwargs_dict.update(tool_args)
         result = await self.tool_runtime_api.invoke_tool(
             tool_name=tool_name_str,
-            kwargs={
-                "session_id": session_id,
-                # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
-                **tool_call.arguments,
-                **self.tool_name_to_args.get(tool_name_str, {}),
-            },
+            kwargs=kwargs_dict,
         )
         logger.debug(f"tool call {tool_name_str} completed with result: {result}")
         return result
@@ -931,7 +956,7 @@ def _interpret_content_as_attachment(
         snippet = match.group(1)
         data = json.loads(snippet)
         return Attachment(
-            url=URL(uri="file://" + data["filepath"]),
+            content=URL(uri="file://" + data["filepath"]),
             mime_type=data["mimetype"],
         )

diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py
index 605f387b7..85804c451 100644
--- a/llama_stack/providers/inline/agents/meta_reference/safety.py
+++ b/llama_stack/providers/inline/agents/meta_reference/safety.py
@@ -28,8 +28,8 @@ class ShieldRunnerMixin:
         output_shields: list[str] | None = None,
     ):
         self.safety_api = safety_api
-        self.input_shields = input_shields
-        self.output_shields = output_shields
+        self.input_shields = input_shields or []
+        self.output_shields = output_shields or []
 
     async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
         async def run_shield_with_span(identifier: str):
diff --git a/pyproject.toml b/pyproject.toml
index 30598e5e3..de8521a6d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -243,7 +243,6 @@ exclude = [
     "^llama_stack/models/llama/llama3/tokenizer\\.py$",
     "^llama_stack/models/llama/llama3/tool_utils\\.py$",
     "^llama_stack/providers/inline/agents/meta_reference/",
-    "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
     "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
     "^llama_stack/providers/inline/datasetio/localfs/",
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
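
Note on the recurring pattern above (illustrative commentary, not part of the patch):
Turn.steps holds a union of step types, and comparing step.step_type alone does not
narrow that union for mypy when step_type is a plain str, so attribute access such as
step.tool_responses fails to type-check. The patch therefore guards each access with
isinstance(). A minimal, self-contained sketch of the technique; the class names below
are simplified stand-ins, not the actual llama_stack definitions:

    from dataclasses import dataclass, field

    @dataclass
    class InferenceStep:
        step_type: str = "inference"
        model_response: str = ""

    @dataclass
    class ToolExecutionStep:
        step_type: str = "tool_execution"
        tool_responses: list[str] = field(default_factory=list)

    # Python 3.10+ union syntax; mirrors the union of step types in Turn.steps.
    Step = InferenceStep | ToolExecutionStep

    def collect_tool_responses(steps: list[Step]) -> list[str]:
        out: list[str] = []
        for step in steps:
            # Checking step.step_type == "tool_execution" would not narrow the
            # union for mypy, but isinstance() does, so the attribute access
            # below type-checks cleanly.
            if isinstance(step, ToolExecutionStep):
                out.extend(step.tool_responses)
        return out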