diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 2fa1dab53..ef5603420 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -128,7 +128,9 @@ class StreamingResponseOrchestrator: self.prompt = prompt self.sequence_number = 0 # Store MCP tool mapping that gets built during tool processing - self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools if ctx.tool_context else {} + self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ( + ctx.tool_context.previous_tools if ctx.tool_context else {} + ) # Track final messages after all tool executions self.final_messages: list[OpenAIMessageParam] = [] # mapping for annotations @@ -1138,7 +1140,9 @@ class StreamingResponseOrchestrator: yield evt # Process all remaining tools (including MCP tools) and emit streaming events if self.ctx.tool_context.tools_to_process: - async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages): + async for stream_event in self._process_new_tools( + self.ctx.tool_context.tools_to_process, output_messages + ): yield stream_event def _approval_required(self, tool_name: str) -> bool: diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py index bca35a44c..1a36998fc 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -330,8 +330,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[ # Look up shields to get their provider_resource_id (actual model ID) model_ids = [] - # list_shields not in Safety interface but available at runtime via API routing - shields_list = await safety_api.list_shields() # type: ignore[attr-defined] + # TODO: list_shields not in Safety interface but available at runtime via API routing + shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined] for guardrail_id in guardrail_ids: matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id] @@ -348,9 +348,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[ for result in response.results: if result.flagged: message = result.user_message or "Content blocked by safety guardrails" - flagged_categories = [ - cat for cat, flagged in result.categories.items() if flagged - ] if result.categories else [] + flagged_categories = ( + [cat for cat, flagged in result.categories.items() if flagged] if result.categories else [] + ) violation_type = result.metadata.get("violation_type", []) if result.metadata else [] if flagged_categories: