refactor: move Authorization validation from API model to handler layer

Per reviewer feedback, API models should be pure data structures without
business logic. Moved the Authorization header validation from the Pydantic
@model_validator in openai_responses.py to the handler in streaming.py.

- Removed @model_validator from OpenAIResponseInputToolMCP
- Added validation at handler level in _process_mcp_tool()
- Maintains the same security check: rejects Authorization in the headers dict (see the usage sketch below)
- Follows separation of concerns: models are data, handlers have logic
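
A hedged sketch of the caller-visible contract after this change (not part of
the diff; the import path and fields such as server_label, server_url, and
authorization are assumptions based on the OpenAI responses schema this model
mirrors):

    from llama_stack.apis.agents.openai_responses import OpenAIResponseInputToolMCP

    # Construction no longer raises -- the model is pure data now:
    tool = OpenAIResponseInputToolMCP(
        type="mcp",
        server_label="docs",
        server_url="https://example.com/mcp",
        headers={"Authorization": "Bearer secret"},  # accepted by the model...
    )
    # ...but rejected when the handler processes the tool in _process_mcp_tool():
    #   ValueError: Authorization header cannot be passed via 'headers'.
    #               Please use the 'authorization' parameter instead.

    # Supported way to pass credentials:
    tool = OpenAIResponseInputToolMCP(
        type="mcp",
        server_label="docs",
        server_url="https://example.com/mcp",
        authorization="Bearer secret",
    )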
commit 50040f3df7
parent 8ce30b71f4
Author: Omar Abdelwahab
Date:   2025-11-07 11:04:27 -08:00

2 changed files with 105 additions and 49 deletions

openai_responses.py

@@ -7,12 +7,12 @@
 from collections.abc import Sequence
 from typing import Annotated, Any, Literal

-from pydantic import BaseModel, Field, model_validator
-from typing_extensions import TypedDict
-
 from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
 from llama_stack.schema_utils import json_schema_type, register_schema
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import TypedDict

 # NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
 # take their YAML and generate this file automatically. Their YAML is available.
@@ -89,7 +89,9 @@ OpenAIResponseInputMessageContent = Annotated[
     | OpenAIResponseInputMessageContentFile,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
+register_schema(
+    OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent"
+)


 @json_schema_type
@@ -191,7 +193,9 @@ OpenAIResponseOutputMessageContent = Annotated[
     OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
+register_schema(
+    OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent"
+)


 @json_schema_type
@@ -203,8 +207,17 @@ class OpenAIResponseMessage(BaseModel):
     scenarios.
     """

-    content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
-    role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
+    content: (
+        str
+        | Sequence[OpenAIResponseInputMessageContent]
+        | Sequence[OpenAIResponseOutputMessageContent]
+    )
+    role: (
+        Literal["system"]
+        | Literal["developer"]
+        | Literal["user"]
+        | Literal["assistant"]
+    )
     type: Literal["message"] = "message"

     # The fields below are not used in all scenarios, but are required in others.
@@ -258,7 +271,9 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
     queries: Sequence[str]
     status: str
     type: Literal["file_search_call"] = "file_search_call"
-    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
+    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = (
+        None
+    )


 @json_schema_type
@@ -403,7 +418,11 @@ class OpenAIResponseText(BaseModel):

 # Must match type Literals of OpenAIResponseInputToolWebSearch below
-WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
+WebSearchToolTypes = [
+    "web_search",
+    "web_search_preview",
+    "web_search_preview_2025_03_11",
+]


 @json_schema_type
@@ -415,11 +434,15 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
     """

     # Must match values of WebSearchToolTypes above
-    type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
-        "web_search"
-    )
+    type: (
+        Literal["web_search"]
+        | Literal["web_search_preview"]
+        | Literal["web_search_preview_2025_03_11"]
+    ) = "web_search"
     # TODO: actually use search_context_size somewhere...
-    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
+    search_context_size: str | None = Field(
+        default="medium", pattern="^low|medium|high$"
+    )
     # TODO: add user_location
@@ -502,22 +525,6 @@ class OpenAIResponseInputToolMCP(BaseModel):
     require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
     allowed_tools: list[str] | AllowedToolsFilter | None = None

-    @model_validator(mode="after")
-    def validate_no_auth_in_headers(self) -> "OpenAIResponseInputToolMCP":
-        """Ensure Authorization header is not passed via headers dict.
-
-        Authorization must be provided via the dedicated 'authorization' parameter
-        to ensure proper security handling and prevent token leakage in responses.
-        """
-        if self.headers:
-            for key in self.headers.keys():
-                if key.lower() == "authorization":
-                    raise ValueError(
-                        "Authorization header cannot be passed via 'headers'. "
-                        "Please use the 'authorization' parameter instead."
-                    )
-        return self
-

 OpenAIResponseInputTool = Annotated[
     OpenAIResponseInputToolWebSearch
@@ -625,7 +632,9 @@ class OpenAIResponseObject(BaseModel):
     temperature: float | None = None
     # Default to text format to avoid breaking the loading of old responses
     # before the field was added. New responses will have this set always.
-    text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
+    text: OpenAIResponseText = OpenAIResponseText(
+        format=OpenAIResponseTextFormat(type="text")
+    )
     top_p: float | None = None
     tools: Sequence[OpenAIResponseTool] | None = None
     truncation: str | None = None
@@ -804,7 +813,9 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.delta"] = "response.function_call_arguments.delta"
+    type: Literal["response.function_call_arguments.delta"] = (
+        "response.function_call_arguments.delta"
+    )


 @json_schema_type
@@ -822,7 +833,9 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
+    type: Literal["response.function_call_arguments.done"] = (
+        "response.function_call_arguments.done"
+    )


 @json_schema_type
@@ -838,7 +851,9 @@ class OpenAIResponseObjectStreamResponseWebSearchCallInProgress(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
+    type: Literal["response.web_search_call.in_progress"] = (
+        "response.web_search_call.in_progress"
+    )


 @json_schema_type
@@ -846,7 +861,9 @@ class OpenAIResponseObjectStreamResponseWebSearchCallSearching(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
+    type: Literal["response.web_search_call.searching"] = (
+        "response.web_search_call.searching"
+    )


 @json_schema_type
@@ -862,13 +879,17 @@ class OpenAIResponseObjectStreamResponseWebSearchCallCompleted(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
+    type: Literal["response.web_search_call.completed"] = (
+        "response.web_search_call.completed"
+    )


 @json_schema_type
 class OpenAIResponseObjectStreamResponseMcpListToolsInProgress(BaseModel):
     sequence_number: int
-    type: Literal["response.mcp_list_tools.in_progress"] = "response.mcp_list_tools.in_progress"
+    type: Literal["response.mcp_list_tools.in_progress"] = (
+        "response.mcp_list_tools.in_progress"
+    )


 @json_schema_type
@@ -880,7 +901,9 @@ class OpenAIResponseObjectStreamResponseMcpListToolsFailed(BaseModel):

 @json_schema_type
 class OpenAIResponseObjectStreamResponseMcpListToolsCompleted(BaseModel):
     sequence_number: int
-    type: Literal["response.mcp_list_tools.completed"] = "response.mcp_list_tools.completed"
+    type: Literal["response.mcp_list_tools.completed"] = (
+        "response.mcp_list_tools.completed"
+    )


 @json_schema_type
@@ -889,7 +912,9 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.mcp_call.arguments.delta"] = "response.mcp_call.arguments.delta"
+    type: Literal["response.mcp_call.arguments.delta"] = (
+        "response.mcp_call.arguments.delta"
+    )


 @json_schema_type
@@ -898,7 +923,9 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.mcp_call.arguments.done"] = "response.mcp_call.arguments.done"
+    type: Literal["response.mcp_call.arguments.done"] = (
+        "response.mcp_call.arguments.done"
+    )


 @json_schema_type
@@ -970,7 +997,9 @@ class OpenAIResponseContentPartReasoningText(BaseModel):

 OpenAIResponseContentPart = Annotated[
-    OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal | OpenAIResponseContentPartReasoningText,
+    OpenAIResponseContentPartOutputText
+    | OpenAIResponseContentPartRefusal
+    | OpenAIResponseContentPartReasoningText,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
@@ -1089,7 +1118,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded(BaseModel):
     part: OpenAIResponseContentPartReasoningSummary
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_part.added"] = "response.reasoning_summary_part.added"
+    type: Literal["response.reasoning_summary_part.added"] = (
+        "response.reasoning_summary_part.added"
+    )


 @json_schema_type
@@ -1109,7 +1140,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartDone(BaseModel):
     part: OpenAIResponseContentPartReasoningSummary
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_part.done"] = "response.reasoning_summary_part.done"
+    type: Literal["response.reasoning_summary_part.done"] = (
+        "response.reasoning_summary_part.done"
+    )


 @json_schema_type
@@ -1129,7 +1162,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta(BaseModel):
     output_index: int
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_text.delta"] = "response.reasoning_summary_text.delta"
+    type: Literal["response.reasoning_summary_text.delta"] = (
+        "response.reasoning_summary_text.delta"
+    )


 @json_schema_type
@@ -1149,7 +1184,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDone(BaseModel):
     output_index: int
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_text.done"] = "response.reasoning_summary_text.done"
+    type: Literal["response.reasoning_summary_text.done"] = (
+        "response.reasoning_summary_text.done"
+    )


 @json_schema_type
@@ -1211,7 +1248,9 @@ class OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded(BaseModel):
     annotation_index: int
     annotation: OpenAIResponseAnnotations
     sequence_number: int
-    type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
+    type: Literal["response.output_text.annotation.added"] = (
+        "response.output_text.annotation.added"
+    )


 @json_schema_type
@@ -1227,7 +1266,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallInProgress(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.in_progress"] = "response.file_search_call.in_progress"
+    type: Literal["response.file_search_call.in_progress"] = (
+        "response.file_search_call.in_progress"
+    )


 @json_schema_type
@@ -1243,7 +1284,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallSearching(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.searching"] = "response.file_search_call.searching"
+    type: Literal["response.file_search_call.searching"] = (
+        "response.file_search_call.searching"
+    )


 @json_schema_type
@@ -1259,7 +1302,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallCompleted(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.completed"] = "response.file_search_call.completed"
+    type: Literal["response.file_search_call.completed"] = (
+        "response.file_search_call.completed"
+    )


 OpenAIResponseObjectStream = Annotated[
@@ -1350,7 +1395,9 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):

     def to_response_object(self) -> OpenAIResponseObject:
         """Convert to OpenAIResponseObject by excluding input field."""
-        return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
+        return OpenAIResponseObject(
+            **{k: v for k, v in self.model_dump().items() if k != "input"}
+        )


 @json_schema_type

streaming.py

@@ -1055,6 +1055,15 @@ class StreamingResponseOrchestrator:
         """Process an MCP tool configuration and emit appropriate streaming events."""
         from llama_stack.providers.utils.tools.mcp import list_mcp_tools

+        # Validate that Authorization header is not passed via headers dict
+        if mcp_tool.headers:
+            for key in mcp_tool.headers.keys():
+                if key.lower() == "authorization":
+                    raise ValueError(
+                        "Authorization header cannot be passed via 'headers'. "
+                        "Please use the 'authorization' parameter instead."
+                    )
+
         # Emit mcp_list_tools.in_progress
         self.sequence_number += 1
         yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
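
A quick pytest-style check of the moved logic (a sketch only: validate_mcp_headers
below is a hypothetical standalone copy of the inline loop added above, written
out to make the case-insensitivity of the check explicit):

    import pytest

    def validate_mcp_headers(headers: dict[str, str] | None) -> None:
        # Same check the handler now performs in _process_mcp_tool().
        if headers:
            for key in headers:
                if key.lower() == "authorization":
                    raise ValueError(
                        "Authorization header cannot be passed via 'headers'. "
                        "Please use the 'authorization' parameter instead."
                    )

    def test_rejects_authorization_in_any_case():
        for key in ("Authorization", "authorization", "AUTHORIZATION"):
            with pytest.raises(ValueError, match="authorization"):
                validate_mcp_headers({key: "Bearer secret"})

    def test_allows_unrelated_headers():
        validate_mcp_headers({"X-Request-Id": "abc123"})  # does not raise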