precommit

Omar Abdelwahab 2025-11-07 12:14:42 -08:00
parent 445135b8cc
commit ccb870c8fb
6 changed files with 49 additions and 112 deletions

View file

@@ -7131,11 +7131,13 @@ components:
           - type: array
           - type: object
         description: >-
-          (Optional) HTTP headers to include when connecting to the server
+          (Optional) HTTP headers to include when connecting to the server (cannot
+          contain Authorization)
       authorization:
         type: string
         description: >-
-          (Optional) OAuth access token for authenticating with the MCP server
+          (Optional) OAuth access token for authenticating with the MCP server (excluded
+          from responses)
       require_approval:
         oneOf:
           - type: string
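
Taken together, the two description changes above document a tightened credential contract for MCP tools: ad-hoc headers may not carry an Authorization header, and the dedicated authorization field is scrubbed from server responses. A minimal client-side sketch of that contract, assuming an OpenAI-compatible client pointed at a Llama Stack deployment (the base URL, model id, server URL, and token below are placeholders):

    from openai import OpenAI

    # Placeholders: adjust base_url and model for your deployment.
    client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

    response = client.responses.create(
        model="my-model",
        input="Summarize the latest docs.",
        tools=[
            {
                "type": "mcp",
                "server_label": "docs",
                "server_url": "https://example.com/mcp",
                # Custom headers are allowed, but an Authorization header here is rejected.
                "headers": {"X-Request-Source": "demo"},
                # The OAuth token goes in its own field and is never echoed back in responses.
                "authorization": "<oauth-access-token>",
            }
        ],
    )
    print(response.output_text)

Keeping the token in its own field lets the server redact one known location uniformly instead of scanning arbitrary header values.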

View file

@@ -6415,11 +6415,13 @@ components:
           - type: array
           - type: object
         description: >-
-          (Optional) HTTP headers to include when connecting to the server
+          (Optional) HTTP headers to include when connecting to the server (cannot
+          contain Authorization)
       authorization:
         type: string
         description: >-
-          (Optional) OAuth access token for authenticating with the MCP server
+          (Optional) OAuth access token for authenticating with the MCP server (excluded
+          from responses)
       require_approval:
         oneOf:
           - type: string

View file

@@ -7131,11 +7131,13 @@ components:
           - type: array
           - type: object
         description: >-
-          (Optional) HTTP headers to include when connecting to the server
+          (Optional) HTTP headers to include when connecting to the server (cannot
+          contain Authorization)
       authorization:
         type: string
         description: >-
-          (Optional) OAuth access token for authenticating with the MCP server
+          (Optional) OAuth access token for authenticating with the MCP server (excluded
+          from responses)
       require_approval:
         oneOf:
           - type: string

View file

@@ -7,12 +7,12 @@
 from collections.abc import Sequence
 from typing import Annotated, Any, Literal
 
-from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
-from llama_stack.schema_utils import json_schema_type, register_schema
-
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import TypedDict
 
+from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
+from llama_stack.schema_utils import json_schema_type, register_schema
+
 # NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
 # take their YAML and generate this file automatically. Their YAML is available.

@@ -89,9 +89,7 @@ OpenAIResponseInputMessageContent = Annotated[
     | OpenAIResponseInputMessageContentFile,
     Field(discriminator="type"),
 ]
-register_schema(
-    OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent"
-)
+register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
 
 
 @json_schema_type

@@ -193,9 +191,7 @@ OpenAIResponseOutputMessageContent = Annotated[
     OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
     Field(discriminator="type"),
 ]
-register_schema(
-    OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent"
-)
+register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
 
 
 @json_schema_type

@@ -207,17 +203,8 @@ class OpenAIResponseMessage(BaseModel):
     scenarios.
     """
 
-    content: (
-        str
-        | Sequence[OpenAIResponseInputMessageContent]
-        | Sequence[OpenAIResponseOutputMessageContent]
-    )
-    role: (
-        Literal["system"]
-        | Literal["developer"]
-        | Literal["user"]
-        | Literal["assistant"]
-    )
+    content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
+    role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
     type: Literal["message"] = "message"
 
     # The fields below are not used in all scenarios, but are required in others.

@@ -271,9 +258,7 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
     queries: Sequence[str]
     status: str
     type: Literal["file_search_call"] = "file_search_call"
-    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = (
-        None
-    )
+    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
 
 
 @json_schema_type

@@ -434,15 +419,11 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
     """
 
     # Must match values of WebSearchToolTypes above
-    type: (
-        Literal["web_search"]
-        | Literal["web_search_preview"]
-        | Literal["web_search_preview_2025_03_11"]
-    ) = "web_search"
-    # TODO: actually use search_context_size somewhere...
-    search_context_size: str | None = Field(
-        default="medium", pattern="^low|medium|high$"
-    )
+    type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
+        "web_search"
+    )
+    # TODO: actually use search_context_size somewhere...
+    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
     # TODO: add user_location

@@ -632,9 +613,7 @@ class OpenAIResponseObject(BaseModel):
     temperature: float | None = None
     # Default to text format to avoid breaking the loading of old responses
     # before the field was added. New responses will have this set always.
-    text: OpenAIResponseText = OpenAIResponseText(
-        format=OpenAIResponseTextFormat(type="text")
-    )
+    text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
     top_p: float | None = None
     tools: Sequence[OpenAIResponseTool] | None = None
     truncation: str | None = None

@@ -813,9 +792,7 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.delta"] = (
-        "response.function_call_arguments.delta"
-    )
+    type: Literal["response.function_call_arguments.delta"] = "response.function_call_arguments.delta"
 
 
 @json_schema_type

@@ -833,9 +810,7 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.done"] = (
-        "response.function_call_arguments.done"
-    )
+    type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
 
 
 @json_schema_type

@@ -851,9 +826,7 @@ class OpenAIResponseObjectStreamResponseWebSearchCallInProgress(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.in_progress"] = (
-        "response.web_search_call.in_progress"
-    )
+    type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
 
 
 @json_schema_type

@@ -861,9 +834,7 @@ class OpenAIResponseObjectStreamResponseWebSearchCallSearching(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.searching"] = (
-        "response.web_search_call.searching"
-    )
+    type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
 
 
 @json_schema_type

@@ -879,17 +850,13 @@ class OpenAIResponseObjectStreamResponseWebSearchCallCompleted(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.completed"] = (
-        "response.web_search_call.completed"
-    )
+    type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
 
 
 @json_schema_type
 class OpenAIResponseObjectStreamResponseMcpListToolsInProgress(BaseModel):
     sequence_number: int
-    type: Literal["response.mcp_list_tools.in_progress"] = (
-        "response.mcp_list_tools.in_progress"
-    )
+    type: Literal["response.mcp_list_tools.in_progress"] = "response.mcp_list_tools.in_progress"
 
 
 @json_schema_type

@@ -901,9 +868,7 @@ class OpenAIResponseObjectStreamResponseMcpListToolsFailed(BaseModel):
 
 @json_schema_type
 class OpenAIResponseObjectStreamResponseMcpListToolsCompleted(BaseModel):
     sequence_number: int
-    type: Literal["response.mcp_list_tools.completed"] = (
-        "response.mcp_list_tools.completed"
-    )
+    type: Literal["response.mcp_list_tools.completed"] = "response.mcp_list_tools.completed"
 
 
 @json_schema_type

@@ -912,9 +877,7 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.mcp_call.arguments.delta"] = (
-        "response.mcp_call.arguments.delta"
-    )
+    type: Literal["response.mcp_call.arguments.delta"] = "response.mcp_call.arguments.delta"
 
 
 @json_schema_type

@@ -923,9 +886,7 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.mcp_call.arguments.done"] = (
-        "response.mcp_call.arguments.done"
-    )
+    type: Literal["response.mcp_call.arguments.done"] = "response.mcp_call.arguments.done"
 
 
 @json_schema_type

@@ -997,9 +958,7 @@ class OpenAIResponseContentPartReasoningText(BaseModel):
 
 OpenAIResponseContentPart = Annotated[
-    OpenAIResponseContentPartOutputText
-    | OpenAIResponseContentPartRefusal
-    | OpenAIResponseContentPartReasoningText,
+    OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal | OpenAIResponseContentPartReasoningText,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")

@@ -1118,9 +1077,7 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded(BaseModel):
     part: OpenAIResponseContentPartReasoningSummary
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_part.added"] = (
-        "response.reasoning_summary_part.added"
-    )
+    type: Literal["response.reasoning_summary_part.added"] = "response.reasoning_summary_part.added"
 
 
 @json_schema_type

@@ -1140,9 +1097,7 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartDone(BaseModel):
     part: OpenAIResponseContentPartReasoningSummary
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_part.done"] = (
-        "response.reasoning_summary_part.done"
-    )
+    type: Literal["response.reasoning_summary_part.done"] = "response.reasoning_summary_part.done"
 
 
 @json_schema_type

@@ -1162,9 +1117,7 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta(BaseModel):
     output_index: int
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_text.delta"] = (
-        "response.reasoning_summary_text.delta"
-    )
+    type: Literal["response.reasoning_summary_text.delta"] = "response.reasoning_summary_text.delta"
 
 
 @json_schema_type

@@ -1184,9 +1137,7 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDone(BaseModel):
     output_index: int
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_text.done"] = (
-        "response.reasoning_summary_text.done"
-    )
+    type: Literal["response.reasoning_summary_text.done"] = "response.reasoning_summary_text.done"
 
 
 @json_schema_type

@@ -1248,9 +1199,7 @@ class OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded(BaseModel):
     annotation_index: int
     annotation: OpenAIResponseAnnotations
     sequence_number: int
-    type: Literal["response.output_text.annotation.added"] = (
-        "response.output_text.annotation.added"
-    )
+    type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
 
 
 @json_schema_type

@@ -1266,9 +1215,7 @@ class OpenAIResponseObjectStreamResponseFileSearchCallInProgress(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.in_progress"] = (
-        "response.file_search_call.in_progress"
-    )
+    type: Literal["response.file_search_call.in_progress"] = "response.file_search_call.in_progress"
 
 
 @json_schema_type

@@ -1284,9 +1231,7 @@ class OpenAIResponseObjectStreamResponseFileSearchCallSearching(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.searching"] = (
-        "response.file_search_call.searching"
-    )
+    type: Literal["response.file_search_call.searching"] = "response.file_search_call.searching"
 
 
 @json_schema_type

@@ -1302,9 +1247,7 @@ class OpenAIResponseObjectStreamResponseFileSearchCallCompleted(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.completed"] = (
-        "response.file_search_call.completed"
-    )
+    type: Literal["response.file_search_call.completed"] = "response.file_search_call.completed"
 
 
 OpenAIResponseObjectStream = Annotated[

@@ -1395,9 +1338,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
     def to_response_object(self) -> OpenAIResponseObject:
         """Convert to OpenAIResponseObject by excluding input field."""
-        return OpenAIResponseObject(
-            **{k: v for k, v in self.model_dump().items() if k != "input"}
-        )
+        return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
 
 
 @json_schema_type
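
The reflow above is behavior-preserving (pre-commit formatting that collapses multi-line unions and assignments back onto single lines), but nearly every touched definition uses the same pattern: a Pydantic discriminated union registered as a named schema. A standalone sketch of that pattern, using hypothetical TextPart/RefusalPart models rather than the real llama_stack types:

    from typing import Annotated, Literal

    from pydantic import BaseModel, Field, TypeAdapter


    class TextPart(BaseModel):
        type: Literal["text"] = "text"
        text: str


    class RefusalPart(BaseModel):
        type: Literal["refusal"] = "refusal"
        refusal: str


    # Pydantic picks the concrete model by inspecting the "type" field, the
    # same mechanism behind OpenAIResponseContentPart and the unions above.
    ContentPart = Annotated[TextPart | RefusalPart, Field(discriminator="type")]

    part = TypeAdapter(ContentPart).validate_python({"type": "refusal", "refusal": "cannot help"})
    assert isinstance(part, RefusalPart)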

View file

@@ -25,9 +25,7 @@ from .config import MCPProviderConfig
 
 logger = get_logger(__name__, category="tools")
 
-class ModelContextProtocolToolRuntimeImpl(
-    ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
-):
+class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
         self.config = config

@@ -47,13 +45,9 @@ class ModelContextProtocolToolRuntimeImpl(
         if mcp_endpoint is None:
             raise ValueError("mcp_endpoint is required")
         headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
-        return await list_mcp_tools(
-            endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
-        )
+        return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
 
-    async def invoke_tool(
-        self, tool_name: str, kwargs: dict[str, Any]
-    ) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
         if tool.metadata is None or tool.metadata.get("endpoint") is None:
             raise ValueError(f"Tool {tool_name} does not have metadata")

@@ -70,9 +64,7 @@ class ModelContextProtocolToolRuntimeImpl(
             authorization=authorization,
         )
 
-    async def get_headers_from_request(
-        self, mcp_endpoint_uri: str
-    ) -> tuple[dict[str, str], str | None]:
+    async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
         """
         Extract headers and authorization from request provider data.
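
get_headers_from_request now returns a (headers, authorization) tuple, and both values are threaded through list_mcp_tools and invoke_tool. A minimal sketch of the separation rule, written as a hypothetical standalone helper rather than the provider's actual implementation; the test below asserts on the same error message:

    def merge_mcp_credentials(
        headers: dict[str, str] | None,
        authorization: str | None,
    ) -> tuple[dict[str, str], str | None]:
        """Keep the OAuth token out of ad-hoc headers (hypothetical helper)."""
        headers = dict(headers or {})
        for key in headers:
            if key.lower() == "authorization":
                # The same error message the test below asserts on.
                raise ValueError("Authorization header cannot be passed via 'headers'")
        return headers, authorization


    # Accepted: the token travels in its own field.
    headers, token = merge_mcp_credentials({"X-Trace-Id": "abc"}, "sk-123")
    assert headers == {"X-Trace-Id": "abc"} and token == "sk-123"

    # Rejected: credentials smuggled through ad-hoc headers.
    # merge_mcp_credentials({"Authorization": "Bearer sk-123"}, None)  # raises ValueError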

View file

@@ -111,9 +111,7 @@ def test_mcp_authorization_error_when_header_provided(compat_client, text_model_id):
     )
 
     # Create response - should raise ValueError for security reasons
-    with pytest.raises(
-        ValueError, match="Authorization header cannot be passed via 'headers'"
-    ):
+    with pytest.raises(ValueError, match="Authorization header cannot be passed via 'headers'"):
         compat_client.responses.create(
             model=text_model_id,
             input="What is the boiling point of myawesomeliquid?",