mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 01:48:05 +00:00)

commit ccb870c8fb (parent 445135b8cc)

    precommit

6 changed files with 49 additions and 112 deletions
docs/static/llama-stack-spec.yaml (vendored, 6 changes)
@@ -6415,11 +6415,13 @@ components:
             - type: array
             - type: object
           description: >-
-            (Optional) HTTP headers to include when connecting to the server
+            (Optional) HTTP headers to include when connecting to the server (cannot
+            contain Authorization)
         authorization:
           type: string
           description: >-
-            (Optional) OAuth access token for authenticating with the MCP server
+            (Optional) OAuth access token for authenticating with the MCP server (excluded
+            from responses)
         require_approval:
           oneOf:
             - type: string
docs/static/stainless-llama-stack-spec.yaml (vendored, 6 changes)
@@ -7131,11 +7131,13 @@ components:
             - type: array
             - type: object
           description: >-
-            (Optional) HTTP headers to include when connecting to the server
+            (Optional) HTTP headers to include when connecting to the server (cannot
+            contain Authorization)
         authorization:
           type: string
           description: >-
-            (Optional) OAuth access token for authenticating with the MCP server
+            (Optional) OAuth access token for authenticating with the MCP server (excluded
+            from responses)
         require_approval:
           oneOf:
             - type: string
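The spec hunks above document the new contract for MCP tools: OAuth tokens go in the dedicated authorization field (and are excluded from responses), while headers may carry anything except Authorization. A minimal client-side sketch of that contract (the client object, model id, label, and URL are illustrative, not from this diff):

    # Hypothetical Responses API call after this change.
    response = client.responses.create(
        model="my-model",
        input="What is the boiling point of myawesomeliquid?",
        tools=[
            {
                "type": "mcp",
                "server_label": "liquids",                  # illustrative
                "server_url": "http://localhost:8000/sse",  # illustrative
                "authorization": "<oauth-access-token>",    # accepted; never echoed back
                # "headers": {"Authorization": "..."}       # rejected: cannot contain Authorization
            }
        ],
    )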
@@ -7,12 +7,12 @@
 from collections.abc import Sequence
 from typing import Annotated, Any, Literal
 
-from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
-from llama_stack.schema_utils import json_schema_type, register_schema
-
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import TypedDict
 
+from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
+from llama_stack.schema_utils import json_schema_type, register_schema
+
 # NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
 # take their YAML and generate this file automatically. Their YAML is available.
 
@@ -89,9 +89,7 @@ OpenAIResponseInputMessageContent = Annotated[
     | OpenAIResponseInputMessageContentFile,
     Field(discriminator="type"),
 ]
-register_schema(
-    OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent"
-)
+register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
 
 
 @json_schema_type
@@ -193,9 +191,7 @@ OpenAIResponseOutputMessageContent = Annotated[
     OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
     Field(discriminator="type"),
 ]
-register_schema(
-    OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent"
-)
+register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
 
 
 @json_schema_type
@@ -207,17 +203,8 @@ class OpenAIResponseMessage(BaseModel):
     scenarios.
     """
 
-    content: (
-        str
-        | Sequence[OpenAIResponseInputMessageContent]
-        | Sequence[OpenAIResponseOutputMessageContent]
-    )
-    role: (
-        Literal["system"]
-        | Literal["developer"]
-        | Literal["user"]
-        | Literal["assistant"]
-    )
+    content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
+    role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
     type: Literal["message"] = "message"
 
     # The fields below are not used in all scenarios, but are required in others.
@@ -271,9 +258,7 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
     queries: Sequence[str]
     status: str
     type: Literal["file_search_call"] = "file_search_call"
-    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = (
-        None
-    )
+    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
 
 
 @json_schema_type
@@ -434,15 +419,11 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
     """
 
     # Must match values of WebSearchToolTypes above
-    type: (
-        Literal["web_search"]
-        | Literal["web_search_preview"]
-        | Literal["web_search_preview_2025_03_11"]
-    ) = "web_search"
-    # TODO: actually use search_context_size somewhere...
-    search_context_size: str | None = Field(
-        default="medium", pattern="^low|medium|high$"
-    )
+    type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
+        "web_search"
+    )
+    # TODO: actually use search_context_size somewhere...
+    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
     # TODO: add user_location
 
 
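The search_context_size reflow above keeps pydantic's Field(pattern=...) validation on one line. A self-contained sketch of the same idiom (class name invented), including a caveat worth knowing about the regex as written:

    from pydantic import BaseModel, Field


    class WebSearchOptions(BaseModel):
        # Same idiom as search_context_size above. Note that "^low|medium|high$"
        # anchors only the first and last alternatives, so a value like
        # "lowest" still passes; "^(low|medium|high)$" would reject it.
        search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")


    WebSearchOptions(search_context_size="high")  # validates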
@@ -632,9 +613,7 @@ class OpenAIResponseObject(BaseModel):
     temperature: float | None = None
     # Default to text format to avoid breaking the loading of old responses
     # before the field was added. New responses will have this set always.
-    text: OpenAIResponseText = OpenAIResponseText(
-        format=OpenAIResponseTextFormat(type="text")
-    )
+    text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
     top_p: float | None = None
     tools: Sequence[OpenAIResponseTool] | None = None
     truncation: str | None = None
@@ -813,9 +792,7 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.delta"] = (
-        "response.function_call_arguments.delta"
-    )
+    type: Literal["response.function_call_arguments.delta"] = "response.function_call_arguments.delta"
 
 
 @json_schema_type
@@ -833,9 +810,7 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.function_call_arguments.done"] = (
-        "response.function_call_arguments.done"
-    )
+    type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
 
 
 @json_schema_type
@@ -851,9 +826,7 @@ class OpenAIResponseObjectStreamResponseWebSearchCallInProgress(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.in_progress"] = (
-        "response.web_search_call.in_progress"
-    )
+    type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
 
 
 @json_schema_type
@@ -861,9 +834,7 @@ class OpenAIResponseObjectStreamResponseWebSearchCallSearching(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.searching"] = (
-        "response.web_search_call.searching"
-    )
+    type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
 
 
 @json_schema_type
@@ -879,17 +850,13 @@ class OpenAIResponseObjectStreamResponseWebSearchCallCompleted(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.web_search_call.completed"] = (
-        "response.web_search_call.completed"
-    )
+    type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
 
 
 @json_schema_type
 class OpenAIResponseObjectStreamResponseMcpListToolsInProgress(BaseModel):
     sequence_number: int
-    type: Literal["response.mcp_list_tools.in_progress"] = (
-        "response.mcp_list_tools.in_progress"
-    )
+    type: Literal["response.mcp_list_tools.in_progress"] = "response.mcp_list_tools.in_progress"
 
 
 @json_schema_type
@@ -901,9 +868,7 @@ class OpenAIResponseObjectStreamResponseMcpListToolsFailed(BaseModel):
 @json_schema_type
 class OpenAIResponseObjectStreamResponseMcpListToolsCompleted(BaseModel):
     sequence_number: int
-    type: Literal["response.mcp_list_tools.completed"] = (
-        "response.mcp_list_tools.completed"
-    )
+    type: Literal["response.mcp_list_tools.completed"] = "response.mcp_list_tools.completed"
 
 
 @json_schema_type
@@ -912,9 +877,7 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.mcp_call.arguments.delta"] = (
-        "response.mcp_call.arguments.delta"
-    )
+    type: Literal["response.mcp_call.arguments.delta"] = "response.mcp_call.arguments.delta"
 
 
 @json_schema_type
@@ -923,9 +886,7 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.mcp_call.arguments.done"] = (
-        "response.mcp_call.arguments.done"
-    )
+    type: Literal["response.mcp_call.arguments.done"] = "response.mcp_call.arguments.done"
 
 
 @json_schema_type
@@ -997,9 +958,7 @@ class OpenAIResponseContentPartReasoningText(BaseModel):
 
 
 OpenAIResponseContentPart = Annotated[
-    OpenAIResponseContentPartOutputText
-    | OpenAIResponseContentPartRefusal
-    | OpenAIResponseContentPartReasoningText,
+    OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal | OpenAIResponseContentPartReasoningText,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
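The Annotated[..., Field(discriminator="type")] unions collapsed above are pydantic discriminated unions: the "type" literal routes parsing straight to the right member. A runnable sketch with invented part types:

    from typing import Annotated, Literal

    from pydantic import BaseModel, Field, TypeAdapter


    class OutputText(BaseModel):
        type: Literal["output_text"] = "output_text"
        text: str


    class Refusal(BaseModel):
        type: Literal["refusal"] = "refusal"
        refusal: str


    ContentPart = Annotated[OutputText | Refusal, Field(discriminator="type")]

    # The discriminator dispatches on "type" directly; no member-by-member retry.
    part = TypeAdapter(ContentPart).validate_python({"type": "refusal", "refusal": "no"})
    assert isinstance(part, Refusal)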
@@ -1118,9 +1077,7 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded(BaseModel):
     part: OpenAIResponseContentPartReasoningSummary
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_part.added"] = (
-        "response.reasoning_summary_part.added"
-    )
+    type: Literal["response.reasoning_summary_part.added"] = "response.reasoning_summary_part.added"
 
 
 @json_schema_type
@@ -1140,9 +1097,7 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryPartDone(BaseModel):
     part: OpenAIResponseContentPartReasoningSummary
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_part.done"] = (
-        "response.reasoning_summary_part.done"
-    )
+    type: Literal["response.reasoning_summary_part.done"] = "response.reasoning_summary_part.done"
 
 
 @json_schema_type
@@ -1162,9 +1117,7 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta(BaseModel):
     output_index: int
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_text.delta"] = (
-        "response.reasoning_summary_text.delta"
-    )
+    type: Literal["response.reasoning_summary_text.delta"] = "response.reasoning_summary_text.delta"
 
 
 @json_schema_type
@@ -1184,9 +1137,7 @@ class OpenAIResponseObjectStreamResponseReasoningSummaryTextDone(BaseModel):
     output_index: int
     sequence_number: int
     summary_index: int
-    type: Literal["response.reasoning_summary_text.done"] = (
-        "response.reasoning_summary_text.done"
-    )
+    type: Literal["response.reasoning_summary_text.done"] = "response.reasoning_summary_text.done"
 
 
 @json_schema_type
@@ -1248,9 +1199,7 @@ class OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded(BaseModel):
     annotation_index: int
     annotation: OpenAIResponseAnnotations
     sequence_number: int
-    type: Literal["response.output_text.annotation.added"] = (
-        "response.output_text.annotation.added"
-    )
+    type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
 
 
 @json_schema_type
@@ -1266,9 +1215,7 @@ class OpenAIResponseObjectStreamResponseFileSearchCallInProgress(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.in_progress"] = (
-        "response.file_search_call.in_progress"
-    )
+    type: Literal["response.file_search_call.in_progress"] = "response.file_search_call.in_progress"
 
 
 @json_schema_type
@@ -1284,9 +1231,7 @@ class OpenAIResponseObjectStreamResponseFileSearchCallSearching(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.searching"] = (
-        "response.file_search_call.searching"
-    )
+    type: Literal["response.file_search_call.searching"] = "response.file_search_call.searching"
 
 
 @json_schema_type
@@ -1302,9 +1247,7 @@ class OpenAIResponseObjectStreamResponseFileSearchCallCompleted(BaseModel):
     item_id: str
     output_index: int
     sequence_number: int
-    type: Literal["response.file_search_call.completed"] = (
-        "response.file_search_call.completed"
-    )
+    type: Literal["response.file_search_call.completed"] = "response.file_search_call.completed"
 
 
 OpenAIResponseObjectStream = Annotated[
@@ -1395,9 +1338,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
 
     def to_response_object(self) -> OpenAIResponseObject:
         """Convert to OpenAIResponseObject by excluding input field."""
-        return OpenAIResponseObject(
-            **{k: v for k, v in self.model_dump().items() if k != "input"}
-        )
+        return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
 
 
 @json_schema_type
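The to_response_object reflow above preserves the dump-filter-rebuild idiom for dropping a field when narrowing a model. Standalone illustration (model names invented):

    from pydantic import BaseModel


    class WithInput(BaseModel):
        output: str
        input: str


    class WithoutInput(BaseModel):
        output: str


    full = WithInput(output="hi", input="secret")
    # Dump to a dict, drop the unwanted key, rebuild the narrower model.
    slim = WithoutInput(**{k: v for k, v in full.model_dump().items() if k != "input"})
    assert slim.model_dump() == {"output": "hi"}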
@@ -25,9 +25,7 @@ from .config import MCPProviderConfig
 logger = get_logger(__name__, category="tools")
 
 
-class ModelContextProtocolToolRuntimeImpl(
-    ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
-):
+class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
         self.config = config
 
@@ -47,13 +45,9 @@ class ModelContextProtocolToolRuntimeImpl(
         if mcp_endpoint is None:
             raise ValueError("mcp_endpoint is required")
         headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
-        return await list_mcp_tools(
-            endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
-        )
+        return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
 
-    async def invoke_tool(
-        self, tool_name: str, kwargs: dict[str, Any]
-    ) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
         if tool.metadata is None or tool.metadata.get("endpoint") is None:
             raise ValueError(f"Tool {tool_name} does not have metadata")
@@ -70,9 +64,7 @@ class ModelContextProtocolToolRuntimeImpl(
             authorization=authorization,
         )
 
-    async def get_headers_from_request(
-        self, mcp_endpoint_uri: str
-    ) -> tuple[dict[str, str], str | None]:
+    async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
         """
         Extract headers and authorization from request provider data.
 
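get_headers_from_request now returns a (headers, authorization) pair, and the test below pins the error raised when a token is smuggled through headers. A hedged sketch of that guard (not the provider's actual code; the real authorization value comes from provider data, not this helper):

    def split_headers(raw: dict[str, str]) -> tuple[dict[str, str], str | None]:
        headers: dict[str, str] = {}
        for key, value in raw.items():
            # Reject tokens passed via headers; the dedicated authorization
            # field is the only supported channel for them.
            if key.lower() == "authorization":
                raise ValueError("Authorization header cannot be passed via 'headers'")
            headers[key] = value
        return headers, None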
@@ -111,9 +111,7 @@ def test_mcp_authorization_error_when_header_provided(compat_client, text_model_
     )
 
     # Create response - should raise ValueError for security reasons
-    with pytest.raises(
-        ValueError, match="Authorization header cannot be passed via 'headers'"
-    ):
+    with pytest.raises(ValueError, match="Authorization header cannot be passed via 'headers'"):
         compat_client.responses.create(
             model=text_model_id,
             input="What is the boiling point of myawesomeliquid?",