diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index ff86e30e1..3a6735cbc 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -1810,7 +1810,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/RegisterScoringFunctionRequestLoose' + $ref: '#/components/schemas/RegisterScoringFunctionRequest' required: true deprecated: true /v1/scoring-functions/{scoring_fn_id}: @@ -3300,7 +3300,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/RegisterDatasetRequestLoose' + $ref: '#/components/schemas/RegisterDatasetRequest' required: true deprecated: true /v1beta/datasets/{dataset_id}: @@ -3557,7 +3557,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/BenchmarkConfig' + $ref: '#/components/schemas/RunEvalRequest' required: true /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}: get: @@ -10586,6 +10586,14 @@ components: - scores title: EvaluateResponse description: The response from an evaluation. + RunEvalRequest: + properties: + benchmark_config: + $ref: '#/components/schemas/BenchmarkConfig' + type: object + required: + - benchmark_config + title: RunEvalRequest Job: properties: job_id: @@ -11169,6 +11177,67 @@ components: - $ref: '#/components/schemas/CompletionInputType' title: CompletionInputType title: StringType | ... (9 variants) + RegisterScoringFunctionRequest: + properties: + scoring_fn_id: + type: string + title: Scoring Fn Id + description: + type: string + title: Description + return_type: + anyOf: + - $ref: '#/components/schemas/StringType' + title: StringType + - $ref: '#/components/schemas/NumberType' + title: NumberType + - $ref: '#/components/schemas/BooleanType' + title: BooleanType + - $ref: '#/components/schemas/ArrayType' + title: ArrayType + - $ref: '#/components/schemas/ObjectType' + title: ObjectType + - $ref: '#/components/schemas/JsonType' + title: JsonType + - $ref: '#/components/schemas/UnionType' + title: UnionType + - $ref: '#/components/schemas/ChatCompletionInputType' + title: ChatCompletionInputType + - $ref: '#/components/schemas/CompletionInputType' + title: CompletionInputType + title: StringType | ... (9 variants) + provider_scoring_fn_id: + anyOf: + - type: string + - type: 'null' + provider_id: + anyOf: + - type: string + - type: 'null' + params: + anyOf: + - oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + title: LLMAsJudgeScoringFnParams + - $ref: '#/components/schemas/RegexParserScoringFnParams' + title: RegexParserScoringFnParams + - $ref: '#/components/schemas/BasicScoringFnParams' + title: BasicScoringFnParams + discriminator: + propertyName: type + mapping: + basic: '#/components/schemas/BasicScoringFnParams' + llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' + regex_parser: '#/components/schemas/RegexParserScoringFnParams' + title: LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams + - type: 'null' + title: Params + type: object + required: + - scoring_fn_id + - description + - return_type + title: RegisterScoringFunctionRequest RegisterShieldRequest: properties: shield_id: @@ -11227,6 +11296,31 @@ components: - $ref: '#/components/schemas/RowsDataSource' title: RowsDataSource title: URIDataSource | RowsDataSource + RegisterDatasetRequest: + properties: + purpose: + $ref: '#/components/schemas/DatasetPurpose' + source: + anyOf: + - $ref: '#/components/schemas/URIDataSource' + title: URIDataSource + - $ref: '#/components/schemas/RowsDataSource' + title: RowsDataSource + title: URIDataSource | RowsDataSource + metadata: + anyOf: + - additionalProperties: true + type: object + - type: 'null' + dataset_id: + anyOf: + - type: string + - type: 'null' + type: object + required: + - purpose + - source + title: RegisterDatasetRequest RegisterBenchmarkRequest: properties: benchmark_id: @@ -11963,41 +12057,6 @@ components: required: - reasoning_tokens title: OutputTokensDetails - RegisterDatasetRequestLoose: - properties: - purpose: - title: Purpose - source: - title: Source - metadata: - title: Metadata - dataset_id: - title: Dataset Id - type: object - required: - - purpose - - source - title: RegisterDatasetRequestLoose - RegisterScoringFunctionRequestLoose: - properties: - scoring_fn_id: - title: Scoring Fn Id - description: - title: Description - return_type: - title: Return Type - provider_scoring_fn_id: - title: Provider Scoring Fn Id - provider_id: - title: Provider Id - params: - title: Params - type: object - required: - - scoring_fn_id - - description - - return_type - title: RegisterScoringFunctionRequestLoose SearchRankingOptions: properties: ranker: diff --git a/docs/docs/building_applications/tools.mdx b/docs/docs/building_applications/tools.mdx index 3b78ec57b..f7b913fef 100644 --- a/docs/docs/building_applications/tools.mdx +++ b/docs/docs/building_applications/tools.mdx @@ -104,23 +104,19 @@ client.toolgroups.register( ) ``` -Note that most of the more useful MCP servers need you to authenticate with them. Many of them use OAuth2.0 for authentication. You can provide authorization headers to send to the MCP server using the "Provider Data" abstraction provided by Llama Stack. When making an agent call, +Note that most of the more useful MCP servers need you to authenticate with them. Many of them use OAuth2.0 for authentication. You can provide the authorization token when creating the Agent: ```python agent = Agent( ..., - tools=["mcp::deepwiki"], - extra_headers={ - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": { - "http://mcp.deepwiki.com/sse": { - "Authorization": "Bearer ", - }, - }, - } - ), - }, + tools=[ + { + "type": "mcp", + "server_url": "https://mcp.deepwiki.com/sse", + "server_label": "mcp::deepwiki", + "authorization": "", # OAuth token (without "Bearer " prefix) + } + ], ) agent.create_turn(...) ``` diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 3bc06d7d7..0bade1866 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -193,7 +193,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/RegisterScoringFunctionRequestLoose' + $ref: '#/components/schemas/RegisterScoringFunctionRequest' required: true deprecated: true /v1/scoring-functions/{scoring_fn_id}: @@ -549,7 +549,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/RegisterDatasetRequestLoose' + $ref: '#/components/schemas/RegisterDatasetRequest' required: true deprecated: true /v1beta/datasets/{dataset_id}: @@ -7429,6 +7429,14 @@ components: - scores title: EvaluateResponse description: The response from an evaluation. + RunEvalRequest: + properties: + benchmark_config: + $ref: '#/components/schemas/BenchmarkConfig' + type: object + required: + - benchmark_config + title: RunEvalRequest Job: properties: job_id: @@ -8012,6 +8020,67 @@ components: - $ref: '#/components/schemas/CompletionInputType' title: CompletionInputType title: StringType | ... (9 variants) + RegisterScoringFunctionRequest: + properties: + scoring_fn_id: + type: string + title: Scoring Fn Id + description: + type: string + title: Description + return_type: + anyOf: + - $ref: '#/components/schemas/StringType' + title: StringType + - $ref: '#/components/schemas/NumberType' + title: NumberType + - $ref: '#/components/schemas/BooleanType' + title: BooleanType + - $ref: '#/components/schemas/ArrayType' + title: ArrayType + - $ref: '#/components/schemas/ObjectType' + title: ObjectType + - $ref: '#/components/schemas/JsonType' + title: JsonType + - $ref: '#/components/schemas/UnionType' + title: UnionType + - $ref: '#/components/schemas/ChatCompletionInputType' + title: ChatCompletionInputType + - $ref: '#/components/schemas/CompletionInputType' + title: CompletionInputType + title: StringType | ... (9 variants) + provider_scoring_fn_id: + anyOf: + - type: string + - type: 'null' + provider_id: + anyOf: + - type: string + - type: 'null' + params: + anyOf: + - oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + title: LLMAsJudgeScoringFnParams + - $ref: '#/components/schemas/RegexParserScoringFnParams' + title: RegexParserScoringFnParams + - $ref: '#/components/schemas/BasicScoringFnParams' + title: BasicScoringFnParams + discriminator: + propertyName: type + mapping: + basic: '#/components/schemas/BasicScoringFnParams' + llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' + regex_parser: '#/components/schemas/RegexParserScoringFnParams' + title: LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams + - type: 'null' + title: Params + type: object + required: + - scoring_fn_id + - description + - return_type + title: RegisterScoringFunctionRequest RegisterShieldRequest: properties: shield_id: @@ -8070,6 +8139,31 @@ components: - $ref: '#/components/schemas/RowsDataSource' title: RowsDataSource title: URIDataSource | RowsDataSource + RegisterDatasetRequest: + properties: + purpose: + $ref: '#/components/schemas/DatasetPurpose' + source: + anyOf: + - $ref: '#/components/schemas/URIDataSource' + title: URIDataSource + - $ref: '#/components/schemas/RowsDataSource' + title: RowsDataSource + title: URIDataSource | RowsDataSource + metadata: + anyOf: + - additionalProperties: true + type: object + - type: 'null' + dataset_id: + anyOf: + - type: string + - type: 'null' + type: object + required: + - purpose + - source + title: RegisterDatasetRequest RegisterBenchmarkRequest: properties: benchmark_id: @@ -8806,41 +8900,6 @@ components: required: - reasoning_tokens title: OutputTokensDetails - RegisterDatasetRequestLoose: - properties: - purpose: - title: Purpose - source: - title: Source - metadata: - title: Metadata - dataset_id: - title: Dataset Id - type: object - required: - - purpose - - source - title: RegisterDatasetRequestLoose - RegisterScoringFunctionRequestLoose: - properties: - scoring_fn_id: - title: Scoring Fn Id - description: - title: Description - return_type: - title: Return Type - provider_scoring_fn_id: - title: Provider Scoring Fn Id - provider_id: - title: Provider Id - params: - title: Params - type: object - required: - - scoring_fn_id - - description - - return_type - title: RegisterScoringFunctionRequestLoose SearchRankingOptions: properties: ranker: diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index 2b36ebf47..4271989d6 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -300,7 +300,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/BenchmarkConfig' + $ref: '#/components/schemas/RunEvalRequest' required: true /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}: get: @@ -6711,6 +6711,14 @@ components: - scores title: EvaluateResponse description: The response from an evaluation. + RunEvalRequest: + properties: + benchmark_config: + $ref: '#/components/schemas/BenchmarkConfig' + type: object + required: + - benchmark_config + title: RunEvalRequest Job: properties: job_id: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index ff86e30e1..3a6735cbc 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -1810,7 +1810,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/RegisterScoringFunctionRequestLoose' + $ref: '#/components/schemas/RegisterScoringFunctionRequest' required: true deprecated: true /v1/scoring-functions/{scoring_fn_id}: @@ -3300,7 +3300,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/RegisterDatasetRequestLoose' + $ref: '#/components/schemas/RegisterDatasetRequest' required: true deprecated: true /v1beta/datasets/{dataset_id}: @@ -3557,7 +3557,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/BenchmarkConfig' + $ref: '#/components/schemas/RunEvalRequest' required: true /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}: get: @@ -10586,6 +10586,14 @@ components: - scores title: EvaluateResponse description: The response from an evaluation. + RunEvalRequest: + properties: + benchmark_config: + $ref: '#/components/schemas/BenchmarkConfig' + type: object + required: + - benchmark_config + title: RunEvalRequest Job: properties: job_id: @@ -11169,6 +11177,67 @@ components: - $ref: '#/components/schemas/CompletionInputType' title: CompletionInputType title: StringType | ... (9 variants) + RegisterScoringFunctionRequest: + properties: + scoring_fn_id: + type: string + title: Scoring Fn Id + description: + type: string + title: Description + return_type: + anyOf: + - $ref: '#/components/schemas/StringType' + title: StringType + - $ref: '#/components/schemas/NumberType' + title: NumberType + - $ref: '#/components/schemas/BooleanType' + title: BooleanType + - $ref: '#/components/schemas/ArrayType' + title: ArrayType + - $ref: '#/components/schemas/ObjectType' + title: ObjectType + - $ref: '#/components/schemas/JsonType' + title: JsonType + - $ref: '#/components/schemas/UnionType' + title: UnionType + - $ref: '#/components/schemas/ChatCompletionInputType' + title: ChatCompletionInputType + - $ref: '#/components/schemas/CompletionInputType' + title: CompletionInputType + title: StringType | ... (9 variants) + provider_scoring_fn_id: + anyOf: + - type: string + - type: 'null' + provider_id: + anyOf: + - type: string + - type: 'null' + params: + anyOf: + - oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + title: LLMAsJudgeScoringFnParams + - $ref: '#/components/schemas/RegexParserScoringFnParams' + title: RegexParserScoringFnParams + - $ref: '#/components/schemas/BasicScoringFnParams' + title: BasicScoringFnParams + discriminator: + propertyName: type + mapping: + basic: '#/components/schemas/BasicScoringFnParams' + llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' + regex_parser: '#/components/schemas/RegexParserScoringFnParams' + title: LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams + - type: 'null' + title: Params + type: object + required: + - scoring_fn_id + - description + - return_type + title: RegisterScoringFunctionRequest RegisterShieldRequest: properties: shield_id: @@ -11227,6 +11296,31 @@ components: - $ref: '#/components/schemas/RowsDataSource' title: RowsDataSource title: URIDataSource | RowsDataSource + RegisterDatasetRequest: + properties: + purpose: + $ref: '#/components/schemas/DatasetPurpose' + source: + anyOf: + - $ref: '#/components/schemas/URIDataSource' + title: URIDataSource + - $ref: '#/components/schemas/RowsDataSource' + title: RowsDataSource + title: URIDataSource | RowsDataSource + metadata: + anyOf: + - additionalProperties: true + type: object + - type: 'null' + dataset_id: + anyOf: + - type: string + - type: 'null' + type: object + required: + - purpose + - source + title: RegisterDatasetRequest RegisterBenchmarkRequest: properties: benchmark_id: @@ -11963,41 +12057,6 @@ components: required: - reasoning_tokens title: OutputTokensDetails - RegisterDatasetRequestLoose: - properties: - purpose: - title: Purpose - source: - title: Source - metadata: - title: Metadata - dataset_id: - title: Dataset Id - type: object - required: - - purpose - - source - title: RegisterDatasetRequestLoose - RegisterScoringFunctionRequestLoose: - properties: - scoring_fn_id: - title: Scoring Fn Id - description: - title: Description - return_type: - title: Return Type - provider_scoring_fn_id: - title: Provider Scoring Fn Id - provider_id: - title: Provider Id - params: - title: Params - type: object - required: - - scoring_fn_id - - description - - return_type - title: RegisterScoringFunctionRequestLoose SearchRankingOptions: properties: ranker: diff --git a/pyproject.toml b/pyproject.toml index bdf8309ad..eea515b09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ dependencies = [ "pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support. "pydantic>=2.11.9", "rich", - "starlette", "termcolor", "tiktoken", "pillow", @@ -50,7 +49,6 @@ dependencies = [ "aiosqlite>=0.21.0", # server - for metadata store "asyncpg", # for metadata store "sqlalchemy[asyncio]>=2.0.41", # server - for conversations - "pyyaml>=6.0.2", "starlette>=0.49.1", ] diff --git a/scripts/openapi_generator/endpoints.py b/scripts/openapi_generator/endpoints.py index 39086f47f..85203cb71 100644 --- a/scripts/openapi_generator/endpoints.py +++ b/scripts/openapi_generator/endpoints.py @@ -15,6 +15,7 @@ import typing from typing import Annotated, Any, get_args, get_origin from fastapi import FastAPI +from fastapi.params import Body as FastAPIBody from pydantic import Field, create_model from llama_stack.log import get_logger @@ -26,6 +27,8 @@ from .state import _extra_body_fields, register_dynamic_model logger = get_logger(name=__name__, category="core") +type QueryParameter = tuple[str, type, Any, bool] + def _to_pascal_case(segment: str) -> str: tokens = re.findall(r"[A-Za-z]+|\d+", segment) @@ -75,12 +78,12 @@ def _create_endpoint_with_request_model( return endpoint -def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_any: bool = False) -> dict[str, tuple]: +def _build_field_definitions(query_parameters: list[QueryParameter], use_any: bool = False) -> dict[str, tuple]: """Build field definitions for a Pydantic model from query parameters.""" from typing import Any field_definitions = {} - for param_name, param_type, default_value in query_parameters: + for param_name, param_type, default_value, _ in query_parameters: if use_any: field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value) continue @@ -108,10 +111,10 @@ def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_ field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value) # Ensure all parameters are included - expected_params = {name for name, _, _ in query_parameters} + expected_params = {name for name, _, _, _ in query_parameters} missing = expected_params - set(field_definitions.keys()) if missing: - for param_name, _, default_value in query_parameters: + for param_name, _, default_value, _ in query_parameters: if param_name in missing: field_definitions[param_name] = ( Any, @@ -126,7 +129,7 @@ def _create_dynamic_request_model( webmethod, method_name: str, http_method: str, - query_parameters: list[tuple[str, type, Any]], + query_parameters: list[QueryParameter], use_any: bool = False, variant_suffix: str | None = None, ) -> type | None: @@ -143,12 +146,12 @@ def _create_dynamic_request_model( def _build_signature_params( - query_parameters: list[tuple[str, type, Any]], + query_parameters: list[QueryParameter], ) -> tuple[list[inspect.Parameter], dict[str, type]]: """Build signature parameters and annotations from query parameters.""" signature_params = [] param_annotations = {} - for param_name, param_type, default_value in query_parameters: + for param_name, param_type, default_value, _ in query_parameters: param_annotations[param_name] = param_type signature_params.append( inspect.Parameter( @@ -219,6 +222,19 @@ def _is_extra_body_field(metadata_item: Any) -> bool: return isinstance(metadata_item, ExtraBodyField) +def _should_embed_parameter(param_type: Any) -> bool: + """Determine whether a parameter should be embedded (wrapped) in the request body.""" + if get_origin(param_type) is Annotated: + args = get_args(param_type) + metadata = args[1:] if len(args) > 1 else [] + for metadata_item in metadata: + if isinstance(metadata_item, FastAPIBody): + # FastAPI treats embed=None as False, so default to False when unset. + return bool(metadata_item.embed) + # Unannotated parameters default to embed=True through create_dynamic_typed_route. + return True + + def _is_async_iterator_type(type_obj: Any) -> bool: """Check if a type is AsyncIterator or AsyncIterable.""" from collections.abc import AsyncIterable, AsyncIterator @@ -282,7 +298,7 @@ def _find_models_for_endpoint( Returns: tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name) - where query_parameters is a list of (name, type, default_value) tuples + where query_parameters is a list of (name, type, default_value, should_embed) tuples and file_form_params is a list of inspect.Parameter objects for File()/Form() params and streaming_response_model is the model for streaming responses (AsyncIterator content) """ @@ -299,7 +315,7 @@ def _find_models_for_endpoint( # Find request model and collect all body parameters request_model = None - query_parameters = [] + query_parameters: list[QueryParameter] = [] file_form_params = [] path_params = set() extra_body_params = [] @@ -325,6 +341,7 @@ def _find_models_for_endpoint( # Check if it's a File() or Form() parameter - these need special handling param_type = param.annotation + param_should_embed = _should_embed_parameter(param_type) if _is_file_or_form_param(param_type): # File() and Form() parameters must be in the function signature directly # They cannot be part of a Pydantic model @@ -350,30 +367,14 @@ def _find_models_for_endpoint( # Store as extra body parameter - exclude from request model extra_body_params.append((param_name, base_type, extra_body_description)) continue + param_type = base_type # Check if it's a Pydantic model (for POST/PUT requests) if hasattr(param_type, "model_json_schema"): - # Collect all body parameters including Pydantic models - # We'll decide later whether to use a single model or create a combined one - query_parameters.append((param_name, param_type, param.default)) - elif get_origin(param_type) is Annotated: - # Handle Annotated types - get the base type - args = get_args(param_type) - if args and hasattr(args[0], "model_json_schema"): - # Collect Pydantic models from Annotated types - query_parameters.append((param_name, args[0], param.default)) - else: - # Regular annotated parameter (but not File/Form, already handled above) - query_parameters.append((param_name, param_type, param.default)) + query_parameters.append((param_name, param_type, param.default, param_should_embed)) else: - # This is likely a body parameter for POST/PUT or query parameter for GET - # Store the parameter info for later use - # Preserve inspect.Parameter.empty to distinguish "no default" from "default=None" - default_value = param.default - - # Extract the base type from union types (e.g., str | None -> str) - # Also make it safe for FastAPI to avoid forward reference issues - query_parameters.append((param_name, param_type, default_value)) + # Regular annotated parameter (but not File/Form, already handled above) + query_parameters.append((param_name, param_type, param.default, param_should_embed)) # Store extra body fields for later use in post-processing # We'll store them when the endpoint is created, as we need the full path @@ -385,8 +386,8 @@ def _find_models_for_endpoint( # Otherwise, we'll create a combined request model from all parameters # BUT: For GET requests, never create a request body - all parameters should be query parameters if is_post_put and len(query_parameters) == 1: - param_name, param_type, default_value = query_parameters[0] - if hasattr(param_type, "model_json_schema"): + param_name, param_type, default_value, should_embed = query_parameters[0] + if hasattr(param_type, "model_json_schema") and not should_embed: request_model = param_type query_parameters = [] # Clear query_parameters so we use the single model @@ -495,7 +496,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api): if file_form_params and is_post_put: signature_params = list(file_form_params) param_annotations = {param.name: param.annotation for param in file_form_params} - for param_name, param_type, default_value in query_parameters: + for param_name, param_type, default_value, _ in query_parameters: signature_params.append( inspect.Parameter( param_name, diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 649bddecb..97b044dbf 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -48,16 +48,10 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") - # Phase 1: Support both old header-based auth AND new authorization parameter - # Get headers and auth from provider data (old approach) - provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri) + # Get other headers from provider data (but NOT authorization) + provider_headers = await self.get_headers_from_request(mcp_endpoint.uri) - # New authorization parameter takes precedence over provider data - final_authorization = authorization or provider_auth - - return await list_mcp_tools( - endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=final_authorization - ) + return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization) async def invoke_tool( self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None @@ -69,39 +63,38 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") - # Phase 1: Support both old header-based auth AND new authorization parameter - # Get headers and auth from provider data (old approach) - provider_headers, provider_auth = await self.get_headers_from_request(endpoint) - - # New authorization parameter takes precedence over provider data - final_authorization = authorization or provider_auth + # Get other headers from provider data (but NOT authorization) + provider_headers = await self.get_headers_from_request(endpoint) return await invoke_mcp_tool( endpoint=endpoint, tool_name=tool_name, kwargs=kwargs, headers=provider_headers, - authorization=final_authorization, + authorization=authorization, ) - async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]: + async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: """ - Extract headers and authorization from request provider data (Phase 1 backward compatibility). + Extract headers from request provider data, excluding authorization. - Phase 1: Temporarily allows Authorization to be passed via mcp_headers for backward compatibility. - Phase 2: Will enforce that Authorization should use the dedicated authorization parameter instead. + Authorization must be provided via the dedicated authorization parameter. + If Authorization is found in mcp_headers, raise an error to guide users to the correct approach. + + Args: + mcp_endpoint_uri: The MCP endpoint URI to match against provider data Returns: - Tuple of (headers_dict, authorization_token) - - headers_dict: All headers except Authorization - - authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None + dict[str, str]: Headers dictionary (without Authorization) + + Raises: + ValueError: If Authorization header is found in mcp_headers """ def canonicalize_uri(uri: str) -> str: return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}" headers = {} - authorization = None provider_data = self.get_request_provider_data() if provider_data and hasattr(provider_data, "mcp_headers") and provider_data.mcp_headers: @@ -109,17 +102,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): continue - # Phase 1: Extract Authorization from mcp_headers for backward compatibility - # (Phase 2 will reject this and require the dedicated authorization parameter) + # Reject Authorization in mcp_headers - must use authorization parameter for key in values.keys(): if key.lower() == "authorization": - # Extract authorization token and strip "Bearer " prefix if present - auth_value = values[key] - if auth_value.startswith("Bearer "): - authorization = auth_value[7:] # Remove "Bearer " prefix - else: - authorization = auth_value - else: - headers[key] = values[key] + raise ValueError( + "Authorization cannot be provided via mcp_headers in provider_data. " + "Please use the dedicated 'authorization' parameter instead. " + "Example: tool_runtime.invoke_tool(..., authorization='your-token')" + ) + headers[key] = values[key] - return headers, authorization + return headers diff --git a/tests/integration/inference/test_tools_with_schemas.py b/tests/integration/inference/test_tools_with_schemas.py index 5b6e69ae3..ab033c381 100644 --- a/tests/integration/inference/test_tools_with_schemas.py +++ b/tests/integration/inference/test_tools_with_schemas.py @@ -9,8 +9,6 @@ Integration tests for inference/chat completion with JSON Schema-based tools. Tests that tools pass through correctly to various LLM providers. """ -import json - import pytest from llama_stack.core.library_client import LlamaStackAsLibraryClient @@ -193,22 +191,11 @@ class TestMCPToolsInChatCompletion: mcp_endpoint=dict(uri=uri), ) - # Use old header-based approach for Phase 1 (backward compatibility) - provider_data = { - "mcp_headers": { - uri: { - "Authorization": f"Bearer {AUTH_TOKEN}", - }, - }, - } - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), - } - + # Use the dedicated authorization parameter # Get the tools from MCP tools_response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Convert to OpenAI format for inference diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py index 1b7f509d2..074a92afb 100644 --- a/tests/integration/tool_runtime/test_mcp.py +++ b/tests/integration/tool_runtime/test_mcp.py @@ -4,8 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json - import pytest from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta @@ -37,32 +35,20 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server): mcp_endpoint=dict(uri=uri), ) - # Use old header-based approach for Phase 1 (backward compatibility) - provider_data = { - "mcp_headers": { - uri: { - "Authorization": f"Bearer {AUTH_TOKEN}", - }, - }, - } - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), - } - - with pytest.raises(Exception, match="Unauthorized"): - llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) - - tools_list = llama_stack_client.tools.list( - toolgroup_id=test_toolgroup_id, - extra_headers=auth_headers, # Use old header-based approach + # Use the dedicated authorization parameter (no more provider_data headers) + # This tests direct tool_runtime.invoke_tool API calls + tools_list = llama_stack_client.tool_runtime.list_tools( + tool_group_id=test_toolgroup_id, + authorization=AUTH_TOKEN, # Use dedicated authorization parameter ) assert len(tools_list) == 2 assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"} + # Invoke tool with authorization parameter response = llama_stack_client.tool_runtime.invoke_tool( tool_name="greet_everyone", kwargs=dict(url="https://www.google.com"), - extra_headers=auth_headers, # Use old header-based approach + authorization=AUTH_TOKEN, # Use dedicated authorization parameter ) content = response.content assert len(content) == 1 diff --git a/tests/integration/tool_runtime/test_mcp_json_schema.py b/tests/integration/tool_runtime/test_mcp_json_schema.py index 719588c7f..6be71caaf 100644 --- a/tests/integration/tool_runtime/test_mcp_json_schema.py +++ b/tests/integration/tool_runtime/test_mcp_json_schema.py @@ -8,8 +8,6 @@ Tests $ref, $defs, and other JSON Schema features through MCP integration. """ -import json - import pytest from llama_stack.core.library_client import LlamaStackAsLibraryClient @@ -122,22 +120,11 @@ class TestMCPSchemaPreservation: mcp_endpoint=dict(uri=uri), ) - # Use old header-based approach for Phase 1 (backward compatibility) - provider_data = { - "mcp_headers": { - uri: { - "Authorization": f"Bearer {AUTH_TOKEN}", - }, - }, - } - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), - } - + # Use the dedicated authorization parameter # List runtime tools response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) tools = response @@ -173,22 +160,11 @@ class TestMCPSchemaPreservation: mcp_endpoint=dict(uri=uri), ) - # Use old header-based approach for Phase 1 (backward compatibility) - provider_data = { - "mcp_headers": { - uri: { - "Authorization": f"Bearer {AUTH_TOKEN}", - }, - }, - } - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), - } - + # Use the dedicated authorization parameter # List tools response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Find book_flight tool (which should have $ref/$defs) @@ -230,21 +206,10 @@ class TestMCPSchemaPreservation: mcp_endpoint=dict(uri=uri), ) - # Use old header-based approach for Phase 1 (backward compatibility) - provider_data = { - "mcp_headers": { - uri: { - "Authorization": f"Bearer {AUTH_TOKEN}", - }, - }, - } - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), - } - + # Use the dedicated authorization parameter response = llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Find get_weather tool @@ -284,22 +249,10 @@ class TestMCPToolInvocation: mcp_endpoint=dict(uri=uri), ) - # Use old header-based approach for Phase 1 (backward compatibility) - provider_data = { - "mcp_headers": { - uri: { - "Authorization": f"Bearer {AUTH_TOKEN}", - }, - }, - } - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), - } - - # List tools to populate the tool index + # Use the dedicated authorization parameter llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Invoke tool with complex nested data @@ -311,7 +264,7 @@ class TestMCPToolInvocation: "shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}}, } }, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Should succeed without schema validation errors @@ -337,29 +290,17 @@ class TestMCPToolInvocation: mcp_endpoint=dict(uri=uri), ) - # Use old header-based approach for Phase 1 (backward compatibility) - provider_data = { - "mcp_headers": { - uri: { - "Authorization": f"Bearer {AUTH_TOKEN}", - }, - }, - } - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), - } - - # List tools to populate the tool index + # Use the dedicated authorization parameter llama_stack_client.tool_runtime.list_tools( tool_group_id=test_toolgroup_id, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) # Test with email format result_email = llama_stack_client.tool_runtime.invoke_tool( tool_name="flexible_contact", kwargs={"contact_info": "user@example.com"}, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) assert result_email.error_message is None @@ -368,7 +309,7 @@ class TestMCPToolInvocation: result_phone = llama_stack_client.tool_runtime.invoke_tool( tool_name="flexible_contact", kwargs={"contact_info": "+15551234567"}, - extra_headers=auth_headers, + authorization=AUTH_TOKEN, ) assert result_phone.error_message is None @@ -400,21 +341,10 @@ class TestAgentWithMCPTools: mcp_endpoint=dict(uri=uri), ) - # Use old header-based approach for Phase 1 (backward compatibility) - provider_data = { - "mcp_headers": { - uri: { - "Authorization": f"Bearer {AUTH_TOKEN}", - }, - }, - } - auth_headers = { - "X-LlamaStack-Provider-Data": json.dumps(provider_data), - } - - tools_list = llama_stack_client.tools.list( - toolgroup_id=test_toolgroup_id, - extra_headers=auth_headers, + # Use the dedicated authorization parameter + tools_list = llama_stack_client.tool_runtime.list_tools( + tool_group_id=test_toolgroup_id, + authorization=AUTH_TOKEN, ) tool_defs = [ { diff --git a/uv.lock b/uv.lock index a343eb5d8..8c648c362 100644 --- a/uv.lock +++ b/uv.lock @@ -2165,10 +2165,8 @@ requires-dist = [ { name = "python-dotenv" }, { name = "python-multipart", specifier = ">=0.0.20" }, { name = "pyyaml", specifier = ">=6.0" }, - { name = "pyyaml", specifier = ">=6.0.2" }, { name = "rich" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, - { name = "starlette" }, { name = "starlette", specifier = ">=0.49.1" }, { name = "termcolor" }, { name = "tiktoken" }, @@ -4656,6 +4654,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/fa/3234f913fe9a6525a7b97c6dad1f51e72b917e6872e051a5e2ffd8b16fbb/ruamel.yaml.clib-0.2.14-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:70eda7703b8126f5e52fcf276e6c0f40b0d314674f896fc58c47b0aef2b9ae83", size = 137970, upload-time = "2025-09-22T19:51:09.472Z" }, { url = "https://files.pythonhosted.org/packages/ef/ec/4edbf17ac2c87fa0845dd366ef8d5852b96eb58fcd65fc1ecf5fe27b4641/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a0cb71ccc6ef9ce36eecb6272c81afdc2f565950cdcec33ae8e6cd8f7fc86f27", size = 739639, upload-time = "2025-09-22T19:51:10.566Z" }, { url = "https://files.pythonhosted.org/packages/15/18/b0e1fafe59051de9e79cdd431863b03593ecfa8341c110affad7c8121efc/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e7cb9ad1d525d40f7d87b6df7c0ff916a66bc52cb61b66ac1b2a16d0c1b07640", size = 764456, upload-time = "2025-09-22T19:51:11.736Z" }, + { url = "https://files.pythonhosted.org/packages/e7/cd/150fdb96b8fab27fe08d8a59fe67554568727981806e6bc2677a16081ec7/ruamel_yaml_clib-0.2.14-cp314-cp314-win32.whl", hash = "sha256:9b4104bf43ca0cd4e6f738cb86326a3b2f6eef00f417bd1e7efb7bdffe74c539", size = 102394, upload-time = "2025-11-14T21:57:36.703Z" }, + { url = "https://files.pythonhosted.org/packages/bd/e6/a3fa40084558c7e1dc9546385f22a93949c890a8b2e445b2ba43935f51da/ruamel_yaml_clib-0.2.14-cp314-cp314-win_amd64.whl", hash = "sha256:13997d7d354a9890ea1ec5937a219817464e5cc344805b37671562a401ca3008", size = 122673, upload-time = "2025-11-14T21:57:38.177Z" }, ] [[package]]