From 50040f3df7ff7eabc37be03b2101f21b46c3a1a3 Mon Sep 17 00:00:00 2001
From: Omar Abdelwahab
Date: Fri, 7 Nov 2025 11:04:27 -0800
Subject: [PATCH] refactor: move Authorization validation from API model to handler layer

Per reviewer feedback, API models should be pure data structures without
business logic. Moved the Authorization header validation from the Pydantic
@model_validator in openai_responses.py to the handler in streaming.py.

- Removed @model_validator from OpenAIResponseInputToolMCP
- Added validation at handler level in _process_mcp_tool()
- Maintains same security check: rejects Authorization in headers dict
- Follows separation of concerns: models are data, handlers have logic
---
 .../apis/agents/openai_responses.py       | 145 ++++++++++++------
 .../meta_reference/responses/streaming.py |   9 ++
 2 files changed, 105 insertions(+), 49 deletions(-)

diff --git a/src/llama_stack/apis/agents/openai_responses.py b/src/llama_stack/apis/agents/openai_responses.py
index d576f51d1..92b0b7f3b 100644
--- a/src/llama_stack/apis/agents/openai_responses.py
+++ b/src/llama_stack/apis/agents/openai_responses.py
@@ -7,12 +7,12 @@
 from collections.abc import Sequence
 from typing import Annotated, Any, Literal
 
-from pydantic import BaseModel, Field, model_validator
-from typing_extensions import TypedDict
-
 from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
 from llama_stack.schema_utils import json_schema_type, register_schema
 
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import TypedDict
+
 
 # NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
 # take their YAML and generate this file automatically. Their YAML is available.
@@ -89,7 +89,9 @@ OpenAIResponseInputMessageContent = Annotated[
     | OpenAIResponseInputMessageContentFile,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
+register_schema(
+    OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent"
+)
 
 
 @json_schema_type
@@ -191,7 +193,9 @@ OpenAIResponseOutputMessageContent = Annotated[
     OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
+register_schema(
+    OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent"
+)
 
 
 @json_schema_type
@@ -203,8 +207,17 @@ class OpenAIResponseMessage(BaseModel):
     scenarios.
     """
 
-    content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
-    role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
+    content: (
+        str
+        | Sequence[OpenAIResponseInputMessageContent]
+        | Sequence[OpenAIResponseOutputMessageContent]
+    )
+    role: (
+        Literal["system"]
+        | Literal["developer"]
+        | Literal["user"]
+        | Literal["assistant"]
+    )
     type: Literal["message"] = "message"
 
     # The fields below are not used in all scenarios, but are required in others.
@@ -258,7 +271,9 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
     queries: Sequence[str]
     status: str
     type: Literal["file_search_call"] = "file_search_call"
-    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
+    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = (
+        None
+    )
 
 
 @json_schema_type
@@ -403,7 +418,11 @@ class OpenAIResponseText(BaseModel):
 
 
 # Must match type Literals of OpenAIResponseInputToolWebSearch below
-WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
+WebSearchToolTypes = [
+    "web_search",
+    "web_search_preview",
+    "web_search_preview_2025_03_11",
+]
 
 
 @json_schema_type
@@ -415,11 +434,15 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
     """
 
     # Must match values of WebSearchToolTypes above
-    type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
-        "web_search"
-    )
+    type: (
+        Literal["web_search"]
+        | Literal["web_search_preview"]
+        | Literal["web_search_preview_2025_03_11"]
+    ) = "web_search"
     # TODO: actually use search_context_size somewhere...
-    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
+    search_context_size: str | None = Field(
+        default="medium", pattern="^low|medium|high$"
+    )
     # TODO: add user_location
 
 
@@ -502,22 +525,6 @@ class OpenAIResponseInputToolMCP(BaseModel):
     require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
     allowed_tools: list[str] | AllowedToolsFilter | None = None
 
-    @model_validator(mode="after")
-    def validate_no_auth_in_headers(self) -> "OpenAIResponseInputToolMCP":
-        """Ensure Authorization header is not passed via headers dict.
-
-        Authorization must be provided via the dedicated 'authorization' parameter
-        to ensure proper security handling and prevent token leakage in responses.
-        """
-        if self.headers:
-            for key in self.headers.keys():
-                if key.lower() == "authorization":
-                    raise ValueError(
-                        "Authorization header cannot be passed via 'headers'. "
-                        "Please use the 'authorization' parameter instead."
-                    )
-        return self
-
 
 OpenAIResponseInputTool = Annotated[
     OpenAIResponseInputToolWebSearch
@@ -625,7 +632,9 @@ class OpenAIResponseObject(BaseModel):
     temperature: float | None = None
     # Default to text format to avoid breaking the loading of old responses
     # before the field was added. New responses will have this set always.
-    text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
+    text: OpenAIResponseText = OpenAIResponseText(
+        format=OpenAIResponseTextFormat(type="text")
+    )
     top_p: float | None = None
     tools: Sequence[OpenAIResponseTool] | None = None
     truncation: str | None = None
@@ -804,7 +813,9 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.delta"] = "response.function_call_arguments.delta"
+    type: Literal["response.function_call_arguments.delta"] = (
+        "response.function_call_arguments.delta"
+    )
 
 
 @json_schema_type
@@ -822,7 +833,9 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
+    type: Literal["response.function_call_arguments.done"] = (
+        "response.function_call_arguments.done"
+    )
 
 
 @json_schema_type
@@ -838,7 +851,9 @@ class OpenAIResponseObjectStreamResponseWebSearchCallInProgress(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
+    type: Literal["response.web_search_call.in_progress"] = (
+        "response.web_search_call.in_progress"
+    )
 
 
 @json_schema_type
@@ -846,7 +861,9 @@ class OpenAIResponseObjectStreamResponseWebSearchCallSearching(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
+    type: Literal["response.web_search_call.searching"] = (
+        "response.web_search_call.searching"
+    )
 
 
 @json_schema_type
@@ -862,13 +879,17 @@ class OpenAIResponseObjectStreamResponseWebSearchCallCompleted(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
+    type: Literal["response.web_search_call.completed"] = (
+        "response.web_search_call.completed"
+    )
 
 
 @json_schema_type
 class OpenAIResponseObjectStreamResponseMcpListToolsInProgress(BaseModel):
     sequence_number: int
-    type: Literal["response.mcp_list_tools.in_progress"] = "response.mcp_list_tools.in_progress"
+    type: Literal["response.mcp_list_tools.in_progress"] = (
+        "response.mcp_list_tools.in_progress"
+    )
 
 
 @json_schema_type
@@ -880,7 +901,9 @@ class OpenAIResponseObjectStreamResponseMcpListToolsFailed(BaseModel):
 @json_schema_type
 class OpenAIResponseObjectStreamResponseMcpListToolsCompleted(BaseModel):
     sequence_number: int
-    type: Literal["response.mcp_list_tools.completed"] = "response.mcp_list_tools.completed"
+    type: Literal["response.mcp_list_tools.completed"] = (
+        "response.mcp_list_tools.completed"
+    )
 
 
 @json_schema_type
@@ -889,7 +912,9 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.mcp_call.arguments.delta"] = "response.mcp_call.arguments.delta"
+    type: Literal["response.mcp_call.arguments.delta"] = (
+        "response.mcp_call.arguments.delta"
+    )
 
 
 @json_schema_type
@@ -898,7 +923,9 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.mcp_call.arguments.done"] = "response.mcp_call.arguments.done"
+    type: Literal["response.mcp_call.arguments.done"] = (
+        "response.mcp_call.arguments.done"
+    )
 
 
 @json_schema_type
@@ -970,7 +997,9 @@ class OpenAIResponseContentPartReasoningText(BaseModel):
 
 
 OpenAIResponseContentPart = Annotated[
-    OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal | OpenAIResponseContentPartReasoningText,
+    OpenAIResponseContentPartOutputText
+    | OpenAIResponseContentPartRefusal
+    | OpenAIResponseContentPartReasoningText,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
@@ -1089,7 +1118,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded(BaseModel):
     part: OpenAIResponseContentPartReasoningSummary
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_part.added"] = "response.reasoning_summary_part.added"
+    type: Literal["response.reasoning_summary_part.added"] = (
+        "response.reasoning_summary_part.added"
+    )
 
 
 @json_schema_type
@@ -1109,7 +1140,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartDone(BaseModel):
     part: OpenAIResponseContentPartReasoningSummary
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_part.done"] = "response.reasoning_summary_part.done"
+    type: Literal["response.reasoning_summary_part.done"] = (
+        "response.reasoning_summary_part.done"
+    )
 
 
 @json_schema_type
@@ -1129,7 +1162,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta(BaseModel):
     output_index: int
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_text.delta"] = "response.reasoning_summary_text.delta"
+    type: Literal["response.reasoning_summary_text.delta"] = (
+        "response.reasoning_summary_text.delta"
+    )
 
 
 @json_schema_type
@@ -1149,7 +1184,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDone(BaseModel):
     output_index: int
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_text.done"] = "response.reasoning_summary_text.done"
+    type: Literal["response.reasoning_summary_text.done"] = (
+        "response.reasoning_summary_text.done"
+    )
 
 
 @json_schema_type
@@ -1211,7 +1248,9 @@ class OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded(BaseModel):
     annotation_index: int
     annotation: OpenAIResponseAnnotations
     sequence_number: int
-    type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
+    type: Literal["response.output_text.annotation.added"] = (
+        "response.output_text.annotation.added"
+    )
 
 
 @json_schema_type
@@ -1227,7 +1266,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallInProgress(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.in_progress"] = "response.file_search_call.in_progress"
+    type: Literal["response.file_search_call.in_progress"] = (
+        "response.file_search_call.in_progress"
+    )
 
 
 @json_schema_type
@@ -1243,7 +1284,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallSearching(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.searching"] = "response.file_search_call.searching"
+    type: Literal["response.file_search_call.searching"] = (
+        "response.file_search_call.searching"
+    )
 
 
 @json_schema_type
@@ -1259,7 +1302,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallCompleted(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.completed"] = "response.file_search_call.completed"
+    type: Literal["response.file_search_call.completed"] = (
+        "response.file_search_call.completed"
+    )
 
 
 OpenAIResponseObjectStream = Annotated[
@@ -1350,7 +1395,9 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
 
     def to_response_object(self) -> OpenAIResponseObject:
         """Convert to OpenAIResponseObject by excluding input field."""
-        return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
+        return OpenAIResponseObject(
+            **{k: v for k, v in self.model_dump().items() if k != "input"}
+        )
 
 
 @json_schema_type
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index ea98d19cd..c9657e361 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -1055,6 +1055,15 @@ class StreamingResponseOrchestrator:
         """Process an MCP tool configuration and emit appropriate streaming events."""
         from llama_stack.providers.utils.tools.mcp import list_mcp_tools
 
+        # Validate that Authorization header is not passed via headers dict
+        if mcp_tool.headers:
+            for key in mcp_tool.headers.keys():
+                if key.lower() == "authorization":
+                    raise ValueError(
+                        "Authorization header cannot be passed via 'headers'. "
+                        "Please use the 'authorization' parameter instead."
+                    )
+
         # Emit mcp_list_tools.in_progress
         self.sequence_number += 1
         yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
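
Reviewer note: a minimal standalone sketch of the check that now lives in
_process_mcp_tool(). The validate_mcp_headers helper name is illustrative only
and is not part of this diff; it simply mirrors the handler-level logic added
in streaming.py, which rejects an Authorization key in the headers dict
regardless of casing and points callers at the dedicated 'authorization'
parameter instead.

# Sketch only: mirrors the check added to _process_mcp_tool() in streaming.py.
# The helper name validate_mcp_headers is hypothetical, not part of this patch.
def validate_mcp_headers(headers: dict[str, str] | None) -> None:
    """Reject an Authorization key (any casing) in an MCP headers dict."""
    if headers:
        for key in headers:
            if key.lower() == "authorization":
                raise ValueError(
                    "Authorization header cannot be passed via 'headers'. "
                    "Please use the 'authorization' parameter instead."
                )

validate_mcp_headers({"X-Custom": "ok"})  # non-auth headers pass through

try:
    validate_mcp_headers({"authorization": "Bearer secret"})
except ValueError as exc:
    print(exc)  # rejected: use the 'authorization' parameter instead

Because the check now runs inside the handler's async generator, a misplaced
Authorization header surfaces while the response stream is being processed
rather than at OpenAIResponseInputToolMCP construction time.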