Mirror of https://github.com/meta-llama/llama-stack.git
refactor: move Authorization validation from API model to handler layer
Per reviewer feedback, API models should be pure data structures without business logic. Moved the Authorization header validation from the Pydantic @model_validator in openai_responses.py to the handler in streaming.py.

- Removed @model_validator from OpenAIResponseInputToolMCP
- Added validation at handler level in _process_mcp_tool()
- Maintains the same security check: rejects Authorization in the headers dict
- Follows separation of concerns: models are data, handlers have logic
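For API consumers the contract is unchanged: an MCP tool that smuggles a token through headers is still refused, just at a different layer. A sketch of the intended request shapes (field names besides headers and authorization follow the OpenAI Responses MCP tool schema and are assumptions here):

    # Rejected: _process_mcp_tool() raises ValueError at processing time
    mcp_tool = {
        "type": "mcp",
        "server_label": "docs",
        "server_url": "https://example.com/mcp",
        "headers": {"Authorization": "Bearer sk-..."},
    }

    # Accepted: credentials go through the dedicated parameter
    mcp_tool = {
        "type": "mcp",
        "server_label": "docs",
        "server_url": "https://example.com/mcp",
        "authorization": "Bearer sk-...",
        "headers": {"X-Custom": "value"},
    }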
parent 8ce30b71f4
commit 50040f3df7

2 changed files with 105 additions and 49 deletions
openai_responses.py

@@ -7,12 +7,12 @@
 from collections.abc import Sequence
 from typing import Annotated, Any, Literal
 
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import TypedDict
+
 from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
 from llama_stack.schema_utils import json_schema_type, register_schema
 
-from pydantic import BaseModel, Field, model_validator
-from typing_extensions import TypedDict
-
 # NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
 # take their YAML and generate this file automatically. Their YAML is available.
 
@@ -89,7 +89,9 @@ OpenAIResponseInputMessageContent = Annotated[
     | OpenAIResponseInputMessageContentFile,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
+register_schema(
+    OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent"
+)
 
 
 @json_schema_type
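Context for readers unfamiliar with the pattern: Field(discriminator="type") makes Pydantic pick the concrete content model from the "type" field during validation. A minimal sketch (the "input_text" literal is an assumption based on the OpenAI schema this file mirrors):

    from pydantic import TypeAdapter

    # Dispatches on the "type" key to the matching member of the union.
    adapter = TypeAdapter(OpenAIResponseInputMessageContent)
    part = adapter.validate_python({"type": "input_text", "text": "hello"})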
@@ -191,7 +193,9 @@ OpenAIResponseOutputMessageContent = Annotated[
     OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
+register_schema(
+    OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent"
+)
 
 
 @json_schema_type
@@ -203,8 +207,17 @@ class OpenAIResponseMessage(BaseModel):
     scenarios.
     """
 
-    content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
-    role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
+    content: (
+        str
+        | Sequence[OpenAIResponseInputMessageContent]
+        | Sequence[OpenAIResponseOutputMessageContent]
+    )
+    role: (
+        Literal["system"]
+        | Literal["developer"]
+        | Literal["user"]
+        | Literal["assistant"]
+    )
     type: Literal["message"] = "message"
 
     # The fields below are not used in all scenarios, but are required in others.
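Since content accepts a plain str alongside the sequence forms, the simplest construction stays a one-liner. A quick illustrative example (assuming the model is importable from this module):

    msg = OpenAIResponseMessage(role="user", content="What is the capital of France?")
    assert msg.type == "message"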
@@ -258,7 +271,9 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
     queries: Sequence[str]
     status: str
     type: Literal["file_search_call"] = "file_search_call"
-    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
+    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = (
+        None
+    )
 
 
 @json_schema_type
@@ -403,7 +418,11 @@ class OpenAIResponseText(BaseModel):
 
 
 # Must match type Literals of OpenAIResponseInputToolWebSearch below
-WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
+WebSearchToolTypes = [
+    "web_search",
+    "web_search_preview",
+    "web_search_preview_2025_03_11",
+]
 
 
 @json_schema_type
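The "must match" comments here and in the class below are enforced only by convention; a small test-time guard could keep them honest. A sketch, assuming the Python 3.10+ union-of-Literal annotation declared below:

    from typing import get_args

    ann = OpenAIResponseInputToolWebSearch.model_fields["type"].annotation
    literal_values = {arg for lit in get_args(ann) for arg in get_args(lit)}
    assert literal_values == set(WebSearchToolTypes)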
@@ -415,11 +434,15 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
     """
 
     # Must match values of WebSearchToolTypes above
-    type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
-        "web_search"
-    )
+    type: (
+        Literal["web_search"]
+        | Literal["web_search_preview"]
+        | Literal["web_search_preview_2025_03_11"]
+    ) = "web_search"
     # TODO: actually use search_context_size somewhere...
-    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
+    search_context_size: str | None = Field(
+        default="medium", pattern="^low|medium|high$"
+    )
     # TODO: add user_location
 
 
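An aside on the (pre-existing) pattern "^low|medium|high$": the anchors bind only to the first and last alternatives, so the middle alternative is unanchored and strings like "mediumx" also pass. A grouped pattern such as "^(low|medium|high)$" would be stricter; a quick demonstration of the precedence:

    import re

    # "medium" alternative is unanchored, so it matches inside longer strings.
    assert re.match("^low|medium|high$", "mediumx")
    assert not re.match("^(low|medium|high)$", "mediumx")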
@@ -502,22 +525,6 @@ class OpenAIResponseInputToolMCP(BaseModel):
     require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
     allowed_tools: list[str] | AllowedToolsFilter | None = None
 
-    @model_validator(mode="after")
-    def validate_no_auth_in_headers(self) -> "OpenAIResponseInputToolMCP":
-        """Ensure Authorization header is not passed via headers dict.
-
-        Authorization must be provided via the dedicated 'authorization' parameter
-        to ensure proper security handling and prevent token leakage in responses.
-        """
-        if self.headers:
-            for key in self.headers.keys():
-                if key.lower() == "authorization":
-                    raise ValueError(
-                        "Authorization header cannot be passed via 'headers'. "
-                        "Please use the 'authorization' parameter instead."
-                    )
-        return self
-
 
 OpenAIResponseInputTool = Annotated[
     OpenAIResponseInputToolWebSearch
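Extracted from the removed validator, the check that now lives in the handler is a case-insensitive scan over header names (HTTP header names are case-insensitive per RFC 9110). As a standalone helper it would look like this hypothetical sketch, not part of this commit:

    def reject_authorization_header(headers: dict[str, str] | None) -> None:
        # Header names are case-insensitive, hence the .lower() comparison.
        if headers:
            for key in headers:
                if key.lower() == "authorization":
                    raise ValueError(
                        "Authorization header cannot be passed via 'headers'. "
                        "Please use the 'authorization' parameter instead."
                    )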
@@ -625,7 +632,9 @@ class OpenAIResponseObject(BaseModel):
     temperature: float | None = None
     # Default to text format to avoid breaking the loading of old responses
     # before the field was added. New responses will have this set always.
-    text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
+    text: OpenAIResponseText = OpenAIResponseText(
+        format=OpenAIResponseTextFormat(type="text")
+    )
     top_p: float | None = None
     tools: Sequence[OpenAIResponseTool] | None = None
     truncation: str | None = None
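A model instance as a field default is safe here because Pydantic v2 copies mutable defaults per instance (unlike plain dataclasses, which reject them outright). A self-contained illustration:

    from pydantic import BaseModel

    class Inner(BaseModel):
        value: str = "text"

    class Outer(BaseModel):
        inner: Inner = Inner()

    x, y = Outer(), Outer()
    x.inner.value = "json"
    assert y.inner.value == "text"  # defaults are copied, not shared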
@@ -804,7 +813,9 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.delta"] = "response.function_call_arguments.delta"
+    type: Literal["response.function_call_arguments.delta"] = (
+        "response.function_call_arguments.delta"
+    )
 
 
 @json_schema_type
|
@ -822,7 +833,9 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel):
|
|||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
|
||||
type: Literal["response.function_call_arguments.done"] = (
|
||||
"response.function_call_arguments.done"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -838,7 +851,9 @@ class OpenAIResponseObjectStreamResponseWebSearchCallInProgress(BaseModel):
|
|||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
|
||||
type: Literal["response.web_search_call.in_progress"] = (
|
||||
"response.web_search_call.in_progress"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -846,7 +861,9 @@ class OpenAIResponseObjectStreamResponseWebSearchCallSearching(BaseModel):
|
|||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
|
||||
type: Literal["response.web_search_call.searching"] = (
|
||||
"response.web_search_call.searching"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -862,13 +879,17 @@ class OpenAIResponseObjectStreamResponseWebSearchCallCompleted(BaseModel):
|
|||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
|
||||
type: Literal["response.web_search_call.completed"] = (
|
||||
"response.web_search_call.completed"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpListToolsInProgress(BaseModel):
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_list_tools.in_progress"] = "response.mcp_list_tools.in_progress"
|
||||
type: Literal["response.mcp_list_tools.in_progress"] = (
|
||||
"response.mcp_list_tools.in_progress"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -880,7 +901,9 @@ class OpenAIResponseObjectStreamResponseMcpListToolsFailed(BaseModel):
|
|||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpListToolsCompleted(BaseModel):
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_list_tools.completed"] = "response.mcp_list_tools.completed"
|
||||
type: Literal["response.mcp_list_tools.completed"] = (
|
||||
"response.mcp_list_tools.completed"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -889,7 +912,9 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(BaseModel):
|
|||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_call.arguments.delta"] = "response.mcp_call.arguments.delta"
|
||||
type: Literal["response.mcp_call.arguments.delta"] = (
|
||||
"response.mcp_call.arguments.delta"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -898,7 +923,9 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(BaseModel):
|
|||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_call.arguments.done"] = "response.mcp_call.arguments.done"
|
||||
type: Literal["response.mcp_call.arguments.done"] = (
|
||||
"response.mcp_call.arguments.done"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -970,7 +997,9 @@ class OpenAIResponseContentPartReasoningText(BaseModel):
|
|||
|
||||
|
||||
OpenAIResponseContentPart = Annotated[
|
||||
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal | OpenAIResponseContentPartReasoningText,
|
||||
OpenAIResponseContentPartOutputText
|
||||
| OpenAIResponseContentPartRefusal
|
||||
| OpenAIResponseContentPartReasoningText,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
|
||||
|
|
@ -1089,7 +1118,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded(BaseModel):
|
|||
part: OpenAIResponseContentPartReasoningSummary
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_part.added"] = "response.reasoning_summary_part.added"
|
||||
type: Literal["response.reasoning_summary_part.added"] = (
|
||||
"response.reasoning_summary_part.added"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -1109,7 +1140,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartDone(BaseModel):
|
|||
part: OpenAIResponseContentPartReasoningSummary
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_part.done"] = "response.reasoning_summary_part.done"
|
||||
type: Literal["response.reasoning_summary_part.done"] = (
|
||||
"response.reasoning_summary_part.done"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -1129,7 +1162,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta(BaseModel):
|
|||
output_index: int
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_text.delta"] = "response.reasoning_summary_text.delta"
|
||||
type: Literal["response.reasoning_summary_text.delta"] = (
|
||||
"response.reasoning_summary_text.delta"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -1149,7 +1184,9 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDone(BaseModel):
|
|||
output_index: int
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_text.done"] = "response.reasoning_summary_text.done"
|
||||
type: Literal["response.reasoning_summary_text.done"] = (
|
||||
"response.reasoning_summary_text.done"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -1211,7 +1248,9 @@ class OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded(BaseModel):
|
|||
annotation_index: int
|
||||
annotation: OpenAIResponseAnnotations
|
||||
sequence_number: int
|
||||
type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
|
||||
type: Literal["response.output_text.annotation.added"] = (
|
||||
"response.output_text.annotation.added"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -1227,7 +1266,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallInProgress(BaseModel):
|
|||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.file_search_call.in_progress"] = "response.file_search_call.in_progress"
|
||||
type: Literal["response.file_search_call.in_progress"] = (
|
||||
"response.file_search_call.in_progress"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -1243,7 +1284,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallSearching(BaseModel):
|
|||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.file_search_call.searching"] = "response.file_search_call.searching"
|
||||
type: Literal["response.file_search_call.searching"] = (
|
||||
"response.file_search_call.searching"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -1259,7 +1302,9 @@ class OpenAIResponseObjectStreamResponseFileSearchCallCompleted(BaseModel):
|
|||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.file_search_call.completed"] = "response.file_search_call.completed"
|
||||
type: Literal["response.file_search_call.completed"] = (
|
||||
"response.file_search_call.completed"
|
||||
)
|
||||
|
||||
|
||||
OpenAIResponseObjectStream = Annotated[
|
||||
|
@@ -1350,7 +1395,9 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
 
     def to_response_object(self) -> OpenAIResponseObject:
         """Convert to OpenAIResponseObject by excluding input field."""
-        return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
+        return OpenAIResponseObject(
+            **{k: v for k, v in self.model_dump().items() if k != "input"}
+        )
 
 
 @json_schema_type
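The dict comprehension works, though Pydantic's model_dump already supports field exclusion directly; an equivalent formulation (sketch):

    def to_response_object(self) -> OpenAIResponseObject:
        # Same effect as the comprehension, via pydantic's exclude parameter.
        return OpenAIResponseObject(**self.model_dump(exclude={"input"}))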
streaming.py

@@ -1055,6 +1055,15 @@ class StreamingResponseOrchestrator:
         """Process an MCP tool configuration and emit appropriate streaming events."""
         from llama_stack.providers.utils.tools.mcp import list_mcp_tools
 
+        # Validate that Authorization header is not passed via headers dict
+        if mcp_tool.headers:
+            for key in mcp_tool.headers.keys():
+                if key.lower() == "authorization":
+                    raise ValueError(
+                        "Authorization header cannot be passed via 'headers'. "
+                        "Please use the 'authorization' parameter instead."
+                    )
+
         # Emit mcp_list_tools.in_progress
         self.sequence_number += 1
         yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
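Net behavioral shift: the same ValueError now surfaces during response processing rather than as a pydantic.ValidationError at request parse time. Illustrative sketch (the import path and constructor fields beyond headers follow the OpenAI MCP tool schema and are assumptions):

    from llama_stack.apis.agents.openai_responses import OpenAIResponseInputToolMCP

    # After this change, construction succeeds even with the forbidden header...
    tool = OpenAIResponseInputToolMCP(
        type="mcp",
        server_label="docs",
        server_url="https://example.com/mcp",
        headers={"Authorization": "Bearer sk-..."},
    )
    # ...and the ValueError is raised later, when the streaming orchestrator
    # reaches _process_mcp_tool() for this tool.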