refactor: move Authorization validation from API model to handler layer

Per reviewer feedback, API models should be pure data structures without
business logic. Moved the Authorization header validation from the Pydantic
@model_validator in openai_responses.py to the handler in streaming.py.

- Removed @model_validator from OpenAIResponseInputToolMCP
- Added validation at handler level in _process_mcp_tool()
- Maintains the same security check: an Authorization key in the headers dict is rejected
- Follows separation of concerns: models are data, handlers have logic
Author: Omar Abdelwahab
Date:   2025-11-07 11:04:27 -08:00
Commit: 50040f3df7 (parent: 8ce30b71f4)
2 changed files with 105 additions and 49 deletions
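The security rule itself is unchanged: tokens must travel through the dedicated 'authorization' parameter rather than the free-form headers dict, to prevent token leakage in responses. A minimal sketch of the intended calling convention (server_label and server_url are illustrative field names; only headers and the 'authorization' parameter named in the error message are grounded in this diff):

    tool = OpenAIResponseInputToolMCP(
        server_label="docs",                        # illustrative
        server_url="https://mcp.example.com/sse",   # illustrative
        headers={"X-Request-Id": "abc123"},         # arbitrary headers remain allowed
        authorization="Bearer <token>",             # credentials go here, never in headers
    )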

openai_responses.py

@@ -7,12 +7,12 @@
 from collections.abc import Sequence
 from typing import Annotated, Any, Literal

-from pydantic import BaseModel, Field, model_validator
-from typing_extensions import TypedDict
 from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
 from llama_stack.schema_utils import json_schema_type, register_schema
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import TypedDict

 # NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
 # take their YAML and generate this file automatically. Their YAML is available.
@@ -89,7 +89,9 @@ OpenAIResponseInputMessageContent = Annotated[
     | OpenAIResponseInputMessageContentFile,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
+register_schema(
+    OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent"
+)

 @json_schema_type
@@ -191,7 +193,9 @@ OpenAIResponseOutputMessageContent = Annotated[
     OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
+register_schema(
+    OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent"
+)

 @json_schema_type
@@ -203,8 +207,17 @@ class OpenAIResponseMessage(BaseModel):
     scenarios.
     """

-    content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
-    role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
+    content: (
+        str
+        | Sequence[OpenAIResponseInputMessageContent]
+        | Sequence[OpenAIResponseOutputMessageContent]
+    )
+    role: (
+        Literal["system"]
+        | Literal["developer"]
+        | Literal["user"]
+        | Literal["assistant"]
+    )
     type: Literal["message"] = "message"

     # The fields below are not used in all scenarios, but are required in others.
@@ -258,7 +271,9 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
     queries: Sequence[str]
     status: str
     type: Literal["file_search_call"] = "file_search_call"
-    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
+    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = (
+        None
+    )

 @json_schema_type
@@ -403,7 +418,11 @@ class OpenAIResponseText(BaseModel):

 # Must match type Literals of OpenAIResponseInputToolWebSearch below
-WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
+WebSearchToolTypes = [
+    "web_search",
+    "web_search_preview",
+    "web_search_preview_2025_03_11",
+]

 @json_schema_type
@@ -415,11 +434,15 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
     """

     # Must match values of WebSearchToolTypes above
-    type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
-        "web_search"
-    )
+    type: (
+        Literal["web_search"]
+        | Literal["web_search_preview"]
+        | Literal["web_search_preview_2025_03_11"]
+    ) = "web_search"
     # TODO: actually use search_context_size somewhere...
-    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
+    search_context_size: str | None = Field(
+        default="medium", pattern="^low|medium|high$"
+    )
     # TODO: add user_location
@@ -502,22 +525,6 @@ class OpenAIResponseInputToolMCP(BaseModel):
     require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
     allowed_tools: list[str] | AllowedToolsFilter | None = None

-    @model_validator(mode="after")
-    def validate_no_auth_in_headers(self) -> "OpenAIResponseInputToolMCP":
-        """Ensure Authorization header is not passed via headers dict.
-
-        Authorization must be provided via the dedicated 'authorization' parameter
-        to ensure proper security handling and prevent token leakage in responses.
-        """
-        if self.headers:
-            for key in self.headers.keys():
-                if key.lower() == "authorization":
-                    raise ValueError(
-                        "Authorization header cannot be passed via 'headers'. "
-                        "Please use the 'authorization' parameter instead."
-                    )
-        return self
-

 OpenAIResponseInputTool = Annotated[
     OpenAIResponseInputToolWebSearch
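With the validator deleted, an instance carrying an Authorization key now parses cleanly; the rejection is deferred to the handler shown in the second file below. A quick sketch of the new failure point (constructor fields other than headers are assumptions for illustration):

    # Before this commit: pydantic.ValidationError at model construction.
    # After: construction succeeds; ValueError is raised later by
    # _process_mcp_tool() in streaming.py.
    tool = OpenAIResponseInputToolMCP(
        server_label="example",                     # illustrative
        server_url="https://mcp.example.com/sse",   # illustrative
        headers={"Authorization": "Bearer token"},  # accepted by the model now
    )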
@@ -625,7 +632,9 @@ class OpenAIResponseObject(BaseModel):
     temperature: float | None = None
     # Default to text format to avoid breaking the loading of old responses
     # before the field was added. New responses will have this set always.
-    text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
+    text: OpenAIResponseText = OpenAIResponseText(
+        format=OpenAIResponseTextFormat(type="text")
+    )
     top_p: float | None = None
     tools: Sequence[OpenAIResponseTool] | None = None
     truncation: str | None = None
@@ -804,7 +813,9 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.delta"] = "response.function_call_arguments.delta"
+    type: Literal["response.function_call_arguments.delta"] = (
+        "response.function_call_arguments.delta"
+    )

 @json_schema_type
@@ -822,7 +833,9 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
+    type: Literal["response.function_call_arguments.done"] = (
+        "response.function_call_arguments.done"
+    )

 @json_schema_type
@@ -838,7 +851,9 @@ class OpenAIResponseObjectStreamResponseWebSearchCallInProgress(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
+    type: Literal["response.web_search_call.in_progress"] = (
+        "response.web_search_call.in_progress"
+    )

 @json_schema_type
@@ -846,7 +861,9 @@ class OpenAIResponseObjectStreamResponseWebSearchCallSearching(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
+    type: Literal["response.web_search_call.searching"] = (
+        "response.web_search_call.searching"
+    )

 @json_schema_type
@@ -862,13 +879,17 @@ class OpenAIResponseObjectStreamResponseWebSearchCallCompleted(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
+    type: Literal["response.web_search_call.completed"] = (
+        "response.web_search_call.completed"
+    )

 @json_schema_type
 class OpenAIResponseObjectStreamResponseMcpListToolsInProgress(BaseModel):
     sequence_number: int
-    type: Literal["response.mcp_list_tools.in_progress"] = "response.mcp_list_tools.in_progress"
+    type: Literal["response.mcp_list_tools.in_progress"] = (
+        "response.mcp_list_tools.in_progress"
+    )

 @json_schema_type
@@ -880,7 +901,9 @@ class OpenAIResponseObjectStreamResponseMcpListToolsFailed(BaseModel):
 @json_schema_type
 class OpenAIResponseObjectStreamResponseMcpListToolsCompleted(BaseModel):
     sequence_number: int
-    type: Literal["response.mcp_list_tools.completed"] = "response.mcp_list_tools.completed"
+    type: Literal["response.mcp_list_tools.completed"] = (
+        "response.mcp_list_tools.completed"
+    )

 @json_schema_type
@@ -889,7 +912,9 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.mcp_call.arguments.delta"] = "response.mcp_call.arguments.delta"
+    type: Literal["response.mcp_call.arguments.delta"] = (
+        "response.mcp_call.arguments.delta"
+    )

 @json_schema_type
@@ -898,7 +923,9 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.mcp_call.arguments.done"] = "response.mcp_call.arguments.done"
+    type: Literal["response.mcp_call.arguments.done"] = (
+        "response.mcp_call.arguments.done"
+    )

 @json_schema_type
@@ -970,7 +997,9 @@ class OpenAIResponseContentPartReasoningText(BaseModel):

 OpenAIResponseContentPart = Annotated[
-    OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal | OpenAIResponseContentPartReasoningText,
+    OpenAIResponseContentPartOutputText
+    | OpenAIResponseContentPartRefusal
+    | OpenAIResponseContentPartReasoningText,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
@@ -1089,7 +1118,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded(BaseModel):
     part: OpenAIResponseContentPartReasoningSummary
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_part.added"] = "response.reasoning_summary_part.added"
+    type: Literal["response.reasoning_summary_part.added"] = (
+        "response.reasoning_summary_part.added"
+    )

 @json_schema_type
@@ -1109,7 +1140,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartDone(BaseModel):
     part: OpenAIResponseContentPartReasoningSummary
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_part.done"] = "response.reasoning_summary_part.done"
+    type: Literal["response.reasoning_summary_part.done"] = (
+        "response.reasoning_summary_part.done"
+    )

 @json_schema_type
@@ -1129,7 +1162,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta(BaseModel):
     output_index: int
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_text.delta"] = "response.reasoning_summary_text.delta"
+    type: Literal["response.reasoning_summary_text.delta"] = (
+        "response.reasoning_summary_text.delta"
+    )

 @json_schema_type
@@ -1149,7 +1184,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDone(BaseModel):
     output_index: int
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_text.done"] = "response.reasoning_summary_text.done"
+    type: Literal["response.reasoning_summary_text.done"] = (
+        "response.reasoning_summary_text.done"
+    )

 @json_schema_type
@@ -1211,7 +1248,9 @@ class OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded(BaseModel):
     annotation_index: int
     annotation: OpenAIResponseAnnotations
     sequence_number: int
-    type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
+    type: Literal["response.output_text.annotation.added"] = (
+        "response.output_text.annotation.added"
+    )

 @json_schema_type
@@ -1227,7 +1266,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallInProgress(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.in_progress"] = "response.file_search_call.in_progress"
+    type: Literal["response.file_search_call.in_progress"] = (
+        "response.file_search_call.in_progress"
+    )

 @json_schema_type
@@ -1243,7 +1284,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallSearching(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.searching"] = "response.file_search_call.searching"
+    type: Literal["response.file_search_call.searching"] = (
+        "response.file_search_call.searching"
+    )

 @json_schema_type
@@ -1259,7 +1302,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallCompleted(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.completed"] = "response.file_search_call.completed"
+    type: Literal["response.file_search_call.completed"] = (
+        "response.file_search_call.completed"
+    )

 OpenAIResponseObjectStream = Annotated[
@@ -1350,7 +1395,9 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):

     def to_response_object(self) -> OpenAIResponseObject:
         """Convert to OpenAIResponseObject by excluding input field."""
-        return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
+        return OpenAIResponseObject(
+            **{k: v for k, v in self.model_dump().items() if k != "input"}
+        )

 @json_schema_type

streaming.py

@@ -1055,6 +1055,15 @@ class StreamingResponseOrchestrator:
         """Process an MCP tool configuration and emit appropriate streaming events."""
         from llama_stack.providers.utils.tools.mcp import list_mcp_tools

+        # Validate that Authorization header is not passed via headers dict
+        if mcp_tool.headers:
+            for key in mcp_tool.headers.keys():
+                if key.lower() == "authorization":
+                    raise ValueError(
+                        "Authorization header cannot be passed via 'headers'. "
+                        "Please use the 'authorization' parameter instead."
+                    )
+
         # Emit mcp_list_tools.in_progress
         self.sequence_number += 1
         yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
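
One timing nuance of the new location: _process_mcp_tool yields streaming events, so the ValueError now surfaces when the event stream is driven rather than at request parsing, where Pydantic previously reported it as a validation error. A self-contained toy illustrating that timing (names simplified; not the real orchestrator):

    import asyncio


    async def process_mcp_tool(headers: dict[str, str]):
        # Stand-in for the real handler: validate first, then stream events.
        for key in headers:
            if key.lower() == "authorization":
                raise ValueError(
                    "Authorization header cannot be passed via 'headers'. "
                    "Please use the 'authorization' parameter instead."
                )
        yield "response.mcp_list_tools.in_progress"
        yield "response.mcp_list_tools.completed"


    async def main() -> None:
        gen = process_mcp_tool({"Authorization": "Bearer token"})
        try:
            # The error fires on the first iteration, not when the
            # generator object is created.
            async for event in gen:
                print(event)
        except ValueError as exc:
            print(f"rejected: {exc}")


    asyncio.run(main())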