Merge branch 'main' into fix-type-hints-syntax

This commit is contained in:
Ashwin Bharambe 2025-11-17 12:27:34 -08:00 committed by GitHub
commit 47027c65a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 394 additions and 321 deletions

View file

@ -1810,7 +1810,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/RegisterScoringFunctionRequestLoose' $ref: '#/components/schemas/RegisterScoringFunctionRequest'
required: true required: true
deprecated: true deprecated: true
/v1/scoring-functions/{scoring_fn_id}: /v1/scoring-functions/{scoring_fn_id}:
@ -3300,7 +3300,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/RegisterDatasetRequestLoose' $ref: '#/components/schemas/RegisterDatasetRequest'
required: true required: true
deprecated: true deprecated: true
/v1beta/datasets/{dataset_id}: /v1beta/datasets/{dataset_id}:
@ -3557,7 +3557,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/BenchmarkConfig' $ref: '#/components/schemas/RunEvalRequest'
required: true required: true
/v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}: /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}:
get: get:
@ -10586,6 +10586,14 @@ components:
- scores - scores
title: EvaluateResponse title: EvaluateResponse
description: The response from an evaluation. description: The response from an evaluation.
RunEvalRequest:
properties:
benchmark_config:
$ref: '#/components/schemas/BenchmarkConfig'
type: object
required:
- benchmark_config
title: RunEvalRequest
Job: Job:
properties: properties:
job_id: job_id:
@ -11169,6 +11177,67 @@ components:
- $ref: '#/components/schemas/CompletionInputType' - $ref: '#/components/schemas/CompletionInputType'
title: CompletionInputType title: CompletionInputType
title: StringType | ... (9 variants) title: StringType | ... (9 variants)
RegisterScoringFunctionRequest:
properties:
scoring_fn_id:
type: string
title: Scoring Fn Id
description:
type: string
title: Description
return_type:
anyOf:
- $ref: '#/components/schemas/StringType'
title: StringType
- $ref: '#/components/schemas/NumberType'
title: NumberType
- $ref: '#/components/schemas/BooleanType'
title: BooleanType
- $ref: '#/components/schemas/ArrayType'
title: ArrayType
- $ref: '#/components/schemas/ObjectType'
title: ObjectType
- $ref: '#/components/schemas/JsonType'
title: JsonType
- $ref: '#/components/schemas/UnionType'
title: UnionType
- $ref: '#/components/schemas/ChatCompletionInputType'
title: ChatCompletionInputType
- $ref: '#/components/schemas/CompletionInputType'
title: CompletionInputType
title: StringType | ... (9 variants)
provider_scoring_fn_id:
anyOf:
- type: string
- type: 'null'
provider_id:
anyOf:
- type: string
- type: 'null'
params:
anyOf:
- oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
title: LLMAsJudgeScoringFnParams
- $ref: '#/components/schemas/RegexParserScoringFnParams'
title: RegexParserScoringFnParams
- $ref: '#/components/schemas/BasicScoringFnParams'
title: BasicScoringFnParams
discriminator:
propertyName: type
mapping:
basic: '#/components/schemas/BasicScoringFnParams'
llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams'
regex_parser: '#/components/schemas/RegexParserScoringFnParams'
title: LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams
- type: 'null'
title: Params
type: object
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequest
RegisterShieldRequest: RegisterShieldRequest:
properties: properties:
shield_id: shield_id:
@ -11227,6 +11296,31 @@ components:
- $ref: '#/components/schemas/RowsDataSource' - $ref: '#/components/schemas/RowsDataSource'
title: RowsDataSource title: RowsDataSource
title: URIDataSource | RowsDataSource title: URIDataSource | RowsDataSource
RegisterDatasetRequest:
properties:
purpose:
$ref: '#/components/schemas/DatasetPurpose'
source:
anyOf:
- $ref: '#/components/schemas/URIDataSource'
title: URIDataSource
- $ref: '#/components/schemas/RowsDataSource'
title: RowsDataSource
title: URIDataSource | RowsDataSource
metadata:
anyOf:
- additionalProperties: true
type: object
- type: 'null'
dataset_id:
anyOf:
- type: string
- type: 'null'
type: object
required:
- purpose
- source
title: RegisterDatasetRequest
RegisterBenchmarkRequest: RegisterBenchmarkRequest:
properties: properties:
benchmark_id: benchmark_id:
@ -11963,41 +12057,6 @@ components:
required: required:
- reasoning_tokens - reasoning_tokens
title: OutputTokensDetails title: OutputTokensDetails
RegisterDatasetRequestLoose:
properties:
purpose:
title: Purpose
source:
title: Source
metadata:
title: Metadata
dataset_id:
title: Dataset Id
type: object
required:
- purpose
- source
title: RegisterDatasetRequestLoose
RegisterScoringFunctionRequestLoose:
properties:
scoring_fn_id:
title: Scoring Fn Id
description:
title: Description
return_type:
title: Return Type
provider_scoring_fn_id:
title: Provider Scoring Fn Id
provider_id:
title: Provider Id
params:
title: Params
type: object
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequestLoose
SearchRankingOptions: SearchRankingOptions:
properties: properties:
ranker: ranker:

View file

@ -104,23 +104,19 @@ client.toolgroups.register(
) )
``` ```
Note that most of the more useful MCP servers need you to authenticate with them. Many of them use OAuth2.0 for authentication. You can provide authorization headers to send to the MCP server using the "Provider Data" abstraction provided by Llama Stack. When making an agent call, Note that most of the more useful MCP servers need you to authenticate with them. Many of them use OAuth2.0 for authentication. You can provide the authorization token when creating the Agent:
```python ```python
agent = Agent( agent = Agent(
..., ...,
tools=["mcp::deepwiki"], tools=[
extra_headers={
"X-LlamaStack-Provider-Data": json.dumps(
{ {
"mcp_headers": { "type": "mcp",
"http://mcp.deepwiki.com/sse": { "server_url": "https://mcp.deepwiki.com/sse",
"Authorization": "Bearer <your_access_token>", "server_label": "mcp::deepwiki",
}, "authorization": "<your_access_token>", # OAuth token (without "Bearer " prefix)
},
} }
), ],
},
) )
agent.create_turn(...) agent.create_turn(...)
``` ```

View file

@ -193,7 +193,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/RegisterScoringFunctionRequestLoose' $ref: '#/components/schemas/RegisterScoringFunctionRequest'
required: true required: true
deprecated: true deprecated: true
/v1/scoring-functions/{scoring_fn_id}: /v1/scoring-functions/{scoring_fn_id}:
@ -549,7 +549,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/RegisterDatasetRequestLoose' $ref: '#/components/schemas/RegisterDatasetRequest'
required: true required: true
deprecated: true deprecated: true
/v1beta/datasets/{dataset_id}: /v1beta/datasets/{dataset_id}:
@ -7429,6 +7429,14 @@ components:
- scores - scores
title: EvaluateResponse title: EvaluateResponse
description: The response from an evaluation. description: The response from an evaluation.
RunEvalRequest:
properties:
benchmark_config:
$ref: '#/components/schemas/BenchmarkConfig'
type: object
required:
- benchmark_config
title: RunEvalRequest
Job: Job:
properties: properties:
job_id: job_id:
@ -8012,6 +8020,67 @@ components:
- $ref: '#/components/schemas/CompletionInputType' - $ref: '#/components/schemas/CompletionInputType'
title: CompletionInputType title: CompletionInputType
title: StringType | ... (9 variants) title: StringType | ... (9 variants)
RegisterScoringFunctionRequest:
properties:
scoring_fn_id:
type: string
title: Scoring Fn Id
description:
type: string
title: Description
return_type:
anyOf:
- $ref: '#/components/schemas/StringType'
title: StringType
- $ref: '#/components/schemas/NumberType'
title: NumberType
- $ref: '#/components/schemas/BooleanType'
title: BooleanType
- $ref: '#/components/schemas/ArrayType'
title: ArrayType
- $ref: '#/components/schemas/ObjectType'
title: ObjectType
- $ref: '#/components/schemas/JsonType'
title: JsonType
- $ref: '#/components/schemas/UnionType'
title: UnionType
- $ref: '#/components/schemas/ChatCompletionInputType'
title: ChatCompletionInputType
- $ref: '#/components/schemas/CompletionInputType'
title: CompletionInputType
title: StringType | ... (9 variants)
provider_scoring_fn_id:
anyOf:
- type: string
- type: 'null'
provider_id:
anyOf:
- type: string
- type: 'null'
params:
anyOf:
- oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
title: LLMAsJudgeScoringFnParams
- $ref: '#/components/schemas/RegexParserScoringFnParams'
title: RegexParserScoringFnParams
- $ref: '#/components/schemas/BasicScoringFnParams'
title: BasicScoringFnParams
discriminator:
propertyName: type
mapping:
basic: '#/components/schemas/BasicScoringFnParams'
llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams'
regex_parser: '#/components/schemas/RegexParserScoringFnParams'
title: LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams
- type: 'null'
title: Params
type: object
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequest
RegisterShieldRequest: RegisterShieldRequest:
properties: properties:
shield_id: shield_id:
@ -8070,6 +8139,31 @@ components:
- $ref: '#/components/schemas/RowsDataSource' - $ref: '#/components/schemas/RowsDataSource'
title: RowsDataSource title: RowsDataSource
title: URIDataSource | RowsDataSource title: URIDataSource | RowsDataSource
RegisterDatasetRequest:
properties:
purpose:
$ref: '#/components/schemas/DatasetPurpose'
source:
anyOf:
- $ref: '#/components/schemas/URIDataSource'
title: URIDataSource
- $ref: '#/components/schemas/RowsDataSource'
title: RowsDataSource
title: URIDataSource | RowsDataSource
metadata:
anyOf:
- additionalProperties: true
type: object
- type: 'null'
dataset_id:
anyOf:
- type: string
- type: 'null'
type: object
required:
- purpose
- source
title: RegisterDatasetRequest
RegisterBenchmarkRequest: RegisterBenchmarkRequest:
properties: properties:
benchmark_id: benchmark_id:
@ -8806,41 +8900,6 @@ components:
required: required:
- reasoning_tokens - reasoning_tokens
title: OutputTokensDetails title: OutputTokensDetails
RegisterDatasetRequestLoose:
properties:
purpose:
title: Purpose
source:
title: Source
metadata:
title: Metadata
dataset_id:
title: Dataset Id
type: object
required:
- purpose
- source
title: RegisterDatasetRequestLoose
RegisterScoringFunctionRequestLoose:
properties:
scoring_fn_id:
title: Scoring Fn Id
description:
title: Description
return_type:
title: Return Type
provider_scoring_fn_id:
title: Provider Scoring Fn Id
provider_id:
title: Provider Id
params:
title: Params
type: object
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequestLoose
SearchRankingOptions: SearchRankingOptions:
properties: properties:
ranker: ranker:

View file

@ -300,7 +300,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/BenchmarkConfig' $ref: '#/components/schemas/RunEvalRequest'
required: true required: true
/v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}: /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}:
get: get:
@ -6711,6 +6711,14 @@ components:
- scores - scores
title: EvaluateResponse title: EvaluateResponse
description: The response from an evaluation. description: The response from an evaluation.
RunEvalRequest:
properties:
benchmark_config:
$ref: '#/components/schemas/BenchmarkConfig'
type: object
required:
- benchmark_config
title: RunEvalRequest
Job: Job:
properties: properties:
job_id: job_id:

View file

@ -1810,7 +1810,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/RegisterScoringFunctionRequestLoose' $ref: '#/components/schemas/RegisterScoringFunctionRequest'
required: true required: true
deprecated: true deprecated: true
/v1/scoring-functions/{scoring_fn_id}: /v1/scoring-functions/{scoring_fn_id}:
@ -3300,7 +3300,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/RegisterDatasetRequestLoose' $ref: '#/components/schemas/RegisterDatasetRequest'
required: true required: true
deprecated: true deprecated: true
/v1beta/datasets/{dataset_id}: /v1beta/datasets/{dataset_id}:
@ -3557,7 +3557,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/BenchmarkConfig' $ref: '#/components/schemas/RunEvalRequest'
required: true required: true
/v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}: /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}:
get: get:
@ -10586,6 +10586,14 @@ components:
- scores - scores
title: EvaluateResponse title: EvaluateResponse
description: The response from an evaluation. description: The response from an evaluation.
RunEvalRequest:
properties:
benchmark_config:
$ref: '#/components/schemas/BenchmarkConfig'
type: object
required:
- benchmark_config
title: RunEvalRequest
Job: Job:
properties: properties:
job_id: job_id:
@ -11169,6 +11177,67 @@ components:
- $ref: '#/components/schemas/CompletionInputType' - $ref: '#/components/schemas/CompletionInputType'
title: CompletionInputType title: CompletionInputType
title: StringType | ... (9 variants) title: StringType | ... (9 variants)
RegisterScoringFunctionRequest:
properties:
scoring_fn_id:
type: string
title: Scoring Fn Id
description:
type: string
title: Description
return_type:
anyOf:
- $ref: '#/components/schemas/StringType'
title: StringType
- $ref: '#/components/schemas/NumberType'
title: NumberType
- $ref: '#/components/schemas/BooleanType'
title: BooleanType
- $ref: '#/components/schemas/ArrayType'
title: ArrayType
- $ref: '#/components/schemas/ObjectType'
title: ObjectType
- $ref: '#/components/schemas/JsonType'
title: JsonType
- $ref: '#/components/schemas/UnionType'
title: UnionType
- $ref: '#/components/schemas/ChatCompletionInputType'
title: ChatCompletionInputType
- $ref: '#/components/schemas/CompletionInputType'
title: CompletionInputType
title: StringType | ... (9 variants)
provider_scoring_fn_id:
anyOf:
- type: string
- type: 'null'
provider_id:
anyOf:
- type: string
- type: 'null'
params:
anyOf:
- oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
title: LLMAsJudgeScoringFnParams
- $ref: '#/components/schemas/RegexParserScoringFnParams'
title: RegexParserScoringFnParams
- $ref: '#/components/schemas/BasicScoringFnParams'
title: BasicScoringFnParams
discriminator:
propertyName: type
mapping:
basic: '#/components/schemas/BasicScoringFnParams'
llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams'
regex_parser: '#/components/schemas/RegexParserScoringFnParams'
title: LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams
- type: 'null'
title: Params
type: object
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequest
RegisterShieldRequest: RegisterShieldRequest:
properties: properties:
shield_id: shield_id:
@ -11227,6 +11296,31 @@ components:
- $ref: '#/components/schemas/RowsDataSource' - $ref: '#/components/schemas/RowsDataSource'
title: RowsDataSource title: RowsDataSource
title: URIDataSource | RowsDataSource title: URIDataSource | RowsDataSource
RegisterDatasetRequest:
properties:
purpose:
$ref: '#/components/schemas/DatasetPurpose'
source:
anyOf:
- $ref: '#/components/schemas/URIDataSource'
title: URIDataSource
- $ref: '#/components/schemas/RowsDataSource'
title: RowsDataSource
title: URIDataSource | RowsDataSource
metadata:
anyOf:
- additionalProperties: true
type: object
- type: 'null'
dataset_id:
anyOf:
- type: string
- type: 'null'
type: object
required:
- purpose
- source
title: RegisterDatasetRequest
RegisterBenchmarkRequest: RegisterBenchmarkRequest:
properties: properties:
benchmark_id: benchmark_id:
@ -11963,41 +12057,6 @@ components:
required: required:
- reasoning_tokens - reasoning_tokens
title: OutputTokensDetails title: OutputTokensDetails
RegisterDatasetRequestLoose:
properties:
purpose:
title: Purpose
source:
title: Source
metadata:
title: Metadata
dataset_id:
title: Dataset Id
type: object
required:
- purpose
- source
title: RegisterDatasetRequestLoose
RegisterScoringFunctionRequestLoose:
properties:
scoring_fn_id:
title: Scoring Fn Id
description:
title: Description
return_type:
title: Return Type
provider_scoring_fn_id:
title: Provider Scoring Fn Id
provider_id:
title: Provider Id
params:
title: Params
type: object
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequestLoose
SearchRankingOptions: SearchRankingOptions:
properties: properties:
ranker: ranker:

View file

@ -38,7 +38,6 @@ dependencies = [
"pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support. "pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support.
"pydantic>=2.11.9", "pydantic>=2.11.9",
"rich", "rich",
"starlette",
"termcolor", "termcolor",
"tiktoken", "tiktoken",
"pillow", "pillow",
@ -50,7 +49,6 @@ dependencies = [
"aiosqlite>=0.21.0", # server - for metadata store "aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store "asyncpg", # for metadata store
"sqlalchemy[asyncio]>=2.0.41", # server - for conversations "sqlalchemy[asyncio]>=2.0.41", # server - for conversations
"pyyaml>=6.0.2",
"starlette>=0.49.1", "starlette>=0.49.1",
] ]

View file

@ -15,6 +15,7 @@ import typing
from typing import Annotated, Any, get_args, get_origin from typing import Annotated, Any, get_args, get_origin
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.params import Body as FastAPIBody
from pydantic import Field, create_model from pydantic import Field, create_model
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -26,6 +27,8 @@ from .state import _extra_body_fields, register_dynamic_model
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
type QueryParameter = tuple[str, type, Any, bool]
def _to_pascal_case(segment: str) -> str: def _to_pascal_case(segment: str) -> str:
tokens = re.findall(r"[A-Za-z]+|\d+", segment) tokens = re.findall(r"[A-Za-z]+|\d+", segment)
@ -75,12 +78,12 @@ def _create_endpoint_with_request_model(
return endpoint return endpoint
def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_any: bool = False) -> dict[str, tuple]: def _build_field_definitions(query_parameters: list[QueryParameter], use_any: bool = False) -> dict[str, tuple]:
"""Build field definitions for a Pydantic model from query parameters.""" """Build field definitions for a Pydantic model from query parameters."""
from typing import Any from typing import Any
field_definitions = {} field_definitions = {}
for param_name, param_type, default_value in query_parameters: for param_name, param_type, default_value, _ in query_parameters:
if use_any: if use_any:
field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value) field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value)
continue continue
@ -108,10 +111,10 @@ def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_
field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value) field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value)
# Ensure all parameters are included # Ensure all parameters are included
expected_params = {name for name, _, _ in query_parameters} expected_params = {name for name, _, _, _ in query_parameters}
missing = expected_params - set(field_definitions.keys()) missing = expected_params - set(field_definitions.keys())
if missing: if missing:
for param_name, _, default_value in query_parameters: for param_name, _, default_value, _ in query_parameters:
if param_name in missing: if param_name in missing:
field_definitions[param_name] = ( field_definitions[param_name] = (
Any, Any,
@ -126,7 +129,7 @@ def _create_dynamic_request_model(
webmethod, webmethod,
method_name: str, method_name: str,
http_method: str, http_method: str,
query_parameters: list[tuple[str, type, Any]], query_parameters: list[QueryParameter],
use_any: bool = False, use_any: bool = False,
variant_suffix: str | None = None, variant_suffix: str | None = None,
) -> type | None: ) -> type | None:
@ -143,12 +146,12 @@ def _create_dynamic_request_model(
def _build_signature_params( def _build_signature_params(
query_parameters: list[tuple[str, type, Any]], query_parameters: list[QueryParameter],
) -> tuple[list[inspect.Parameter], dict[str, type]]: ) -> tuple[list[inspect.Parameter], dict[str, type]]:
"""Build signature parameters and annotations from query parameters.""" """Build signature parameters and annotations from query parameters."""
signature_params = [] signature_params = []
param_annotations = {} param_annotations = {}
for param_name, param_type, default_value in query_parameters: for param_name, param_type, default_value, _ in query_parameters:
param_annotations[param_name] = param_type param_annotations[param_name] = param_type
signature_params.append( signature_params.append(
inspect.Parameter( inspect.Parameter(
@ -219,6 +222,19 @@ def _is_extra_body_field(metadata_item: Any) -> bool:
return isinstance(metadata_item, ExtraBodyField) return isinstance(metadata_item, ExtraBodyField)
def _should_embed_parameter(param_type: Any) -> bool:
"""Determine whether a parameter should be embedded (wrapped) in the request body."""
if get_origin(param_type) is Annotated:
args = get_args(param_type)
metadata = args[1:] if len(args) > 1 else []
for metadata_item in metadata:
if isinstance(metadata_item, FastAPIBody):
# FastAPI treats embed=None as False, so default to False when unset.
return bool(metadata_item.embed)
# Unannotated parameters default to embed=True through create_dynamic_typed_route.
return True
def _is_async_iterator_type(type_obj: Any) -> bool: def _is_async_iterator_type(type_obj: Any) -> bool:
"""Check if a type is AsyncIterator or AsyncIterable.""" """Check if a type is AsyncIterator or AsyncIterable."""
from collections.abc import AsyncIterable, AsyncIterator from collections.abc import AsyncIterable, AsyncIterator
@ -282,7 +298,7 @@ def _find_models_for_endpoint(
Returns: Returns:
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name) tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name)
where query_parameters is a list of (name, type, default_value) tuples where query_parameters is a list of (name, type, default_value, should_embed) tuples
and file_form_params is a list of inspect.Parameter objects for File()/Form() params and file_form_params is a list of inspect.Parameter objects for File()/Form() params
and streaming_response_model is the model for streaming responses (AsyncIterator content) and streaming_response_model is the model for streaming responses (AsyncIterator content)
""" """
@ -299,7 +315,7 @@ def _find_models_for_endpoint(
# Find request model and collect all body parameters # Find request model and collect all body parameters
request_model = None request_model = None
query_parameters = [] query_parameters: list[QueryParameter] = []
file_form_params = [] file_form_params = []
path_params = set() path_params = set()
extra_body_params = [] extra_body_params = []
@ -325,6 +341,7 @@ def _find_models_for_endpoint(
# Check if it's a File() or Form() parameter - these need special handling # Check if it's a File() or Form() parameter - these need special handling
param_type = param.annotation param_type = param.annotation
param_should_embed = _should_embed_parameter(param_type)
if _is_file_or_form_param(param_type): if _is_file_or_form_param(param_type):
# File() and Form() parameters must be in the function signature directly # File() and Form() parameters must be in the function signature directly
# They cannot be part of a Pydantic model # They cannot be part of a Pydantic model
@ -350,30 +367,14 @@ def _find_models_for_endpoint(
# Store as extra body parameter - exclude from request model # Store as extra body parameter - exclude from request model
extra_body_params.append((param_name, base_type, extra_body_description)) extra_body_params.append((param_name, base_type, extra_body_description))
continue continue
param_type = base_type
# Check if it's a Pydantic model (for POST/PUT requests) # Check if it's a Pydantic model (for POST/PUT requests)
if hasattr(param_type, "model_json_schema"): if hasattr(param_type, "model_json_schema"):
# Collect all body parameters including Pydantic models query_parameters.append((param_name, param_type, param.default, param_should_embed))
# We'll decide later whether to use a single model or create a combined one
query_parameters.append((param_name, param_type, param.default))
elif get_origin(param_type) is Annotated:
# Handle Annotated types - get the base type
args = get_args(param_type)
if args and hasattr(args[0], "model_json_schema"):
# Collect Pydantic models from Annotated types
query_parameters.append((param_name, args[0], param.default))
else: else:
# Regular annotated parameter (but not File/Form, already handled above) # Regular annotated parameter (but not File/Form, already handled above)
query_parameters.append((param_name, param_type, param.default)) query_parameters.append((param_name, param_type, param.default, param_should_embed))
else:
# This is likely a body parameter for POST/PUT or query parameter for GET
# Store the parameter info for later use
# Preserve inspect.Parameter.empty to distinguish "no default" from "default=None"
default_value = param.default
# Extract the base type from union types (e.g., str | None -> str)
# Also make it safe for FastAPI to avoid forward reference issues
query_parameters.append((param_name, param_type, default_value))
# Store extra body fields for later use in post-processing # Store extra body fields for later use in post-processing
# We'll store them when the endpoint is created, as we need the full path # We'll store them when the endpoint is created, as we need the full path
@ -385,8 +386,8 @@ def _find_models_for_endpoint(
# Otherwise, we'll create a combined request model from all parameters # Otherwise, we'll create a combined request model from all parameters
# BUT: For GET requests, never create a request body - all parameters should be query parameters # BUT: For GET requests, never create a request body - all parameters should be query parameters
if is_post_put and len(query_parameters) == 1: if is_post_put and len(query_parameters) == 1:
param_name, param_type, default_value = query_parameters[0] param_name, param_type, default_value, should_embed = query_parameters[0]
if hasattr(param_type, "model_json_schema"): if hasattr(param_type, "model_json_schema") and not should_embed:
request_model = param_type request_model = param_type
query_parameters = [] # Clear query_parameters so we use the single model query_parameters = [] # Clear query_parameters so we use the single model
@ -495,7 +496,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
if file_form_params and is_post_put: if file_form_params and is_post_put:
signature_params = list(file_form_params) signature_params = list(file_form_params)
param_annotations = {param.name: param.annotation for param in file_form_params} param_annotations = {param.name: param.annotation for param in file_form_params}
for param_name, param_type, default_value in query_parameters: for param_name, param_type, default_value, _ in query_parameters:
signature_params.append( signature_params.append(
inspect.Parameter( inspect.Parameter(
param_name, param_name,

View file

@ -48,16 +48,10 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if mcp_endpoint is None: if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required") raise ValueError("mcp_endpoint is required")
# Phase 1: Support both old header-based auth AND new authorization parameter # Get other headers from provider data (but NOT authorization)
# Get headers and auth from provider data (old approach) provider_headers = await self.get_headers_from_request(mcp_endpoint.uri)
provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri)
# New authorization parameter takes precedence over provider data return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization)
final_authorization = authorization or provider_auth
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=final_authorization
)
async def invoke_tool( async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
@ -69,39 +63,38 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"): if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
# Phase 1: Support both old header-based auth AND new authorization parameter # Get other headers from provider data (but NOT authorization)
# Get headers and auth from provider data (old approach) provider_headers = await self.get_headers_from_request(endpoint)
provider_headers, provider_auth = await self.get_headers_from_request(endpoint)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
return await invoke_mcp_tool( return await invoke_mcp_tool(
endpoint=endpoint, endpoint=endpoint,
tool_name=tool_name, tool_name=tool_name,
kwargs=kwargs, kwargs=kwargs,
headers=provider_headers, headers=provider_headers,
authorization=final_authorization, authorization=authorization,
) )
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]: async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
""" """
Extract headers and authorization from request provider data (Phase 1 backward compatibility). Extract headers from request provider data, excluding authorization.
Phase 1: Temporarily allows Authorization to be passed via mcp_headers for backward compatibility. Authorization must be provided via the dedicated authorization parameter.
Phase 2: Will enforce that Authorization should use the dedicated authorization parameter instead. If Authorization is found in mcp_headers, raise an error to guide users to the correct approach.
Args:
mcp_endpoint_uri: The MCP endpoint URI to match against provider data
Returns: Returns:
Tuple of (headers_dict, authorization_token) dict[str, str]: Headers dictionary (without Authorization)
- headers_dict: All headers except Authorization
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None Raises:
ValueError: If Authorization header is found in mcp_headers
""" """
def canonicalize_uri(uri: str) -> str: def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}" return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {} headers = {}
authorization = None
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data and hasattr(provider_data, "mcp_headers") and provider_data.mcp_headers: if provider_data and hasattr(provider_data, "mcp_headers") and provider_data.mcp_headers:
@ -109,17 +102,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue continue
# Phase 1: Extract Authorization from mcp_headers for backward compatibility # Reject Authorization in mcp_headers - must use authorization parameter
# (Phase 2 will reject this and require the dedicated authorization parameter)
for key in values.keys(): for key in values.keys():
if key.lower() == "authorization": if key.lower() == "authorization":
# Extract authorization token and strip "Bearer " prefix if present raise ValueError(
auth_value = values[key] "Authorization cannot be provided via mcp_headers in provider_data. "
if auth_value.startswith("Bearer "): "Please use the dedicated 'authorization' parameter instead. "
authorization = auth_value[7:] # Remove "Bearer " prefix "Example: tool_runtime.invoke_tool(..., authorization='your-token')"
else: )
authorization = auth_value
else:
headers[key] = values[key] headers[key] = values[key]
return headers, authorization return headers

View file

@ -9,8 +9,6 @@ Integration tests for inference/chat completion with JSON Schema-based tools.
Tests that tools pass through correctly to various LLM providers. Tests that tools pass through correctly to various LLM providers.
""" """
import json
import pytest import pytest
from llama_stack.core.library_client import LlamaStackAsLibraryClient from llama_stack.core.library_client import LlamaStackAsLibraryClient
@ -193,22 +191,11 @@ class TestMCPToolsInChatCompletion:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility) # Use the dedicated authorization parameter
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# Get the tools from MCP # Get the tools from MCP
tools_response = llama_stack_client.tool_runtime.list_tools( tools_response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Convert to OpenAI format for inference # Convert to OpenAI format for inference

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
import pytest import pytest
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta
@ -37,32 +35,20 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility) # Use the dedicated authorization parameter (no more provider_data headers)
provider_data = { # This tests direct tool_runtime.invoke_tool API calls
"mcp_headers": { tools_list = llama_stack_client.tool_runtime.list_tools(
uri: { tool_group_id=test_toolgroup_id,
"Authorization": f"Bearer {AUTH_TOKEN}", authorization=AUTH_TOKEN, # Use dedicated authorization parameter
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers, # Use old header-based approach
) )
assert len(tools_list) == 2 assert len(tools_list) == 2
assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"} assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"}
# Invoke tool with authorization parameter
response = llama_stack_client.tool_runtime.invoke_tool( response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="greet_everyone", tool_name="greet_everyone",
kwargs=dict(url="https://www.google.com"), kwargs=dict(url="https://www.google.com"),
extra_headers=auth_headers, # Use old header-based approach authorization=AUTH_TOKEN, # Use dedicated authorization parameter
) )
content = response.content content = response.content
assert len(content) == 1 assert len(content) == 1

View file

@ -8,8 +8,6 @@
Tests $ref, $defs, and other JSON Schema features through MCP integration. Tests $ref, $defs, and other JSON Schema features through MCP integration.
""" """
import json
import pytest import pytest
from llama_stack.core.library_client import LlamaStackAsLibraryClient from llama_stack.core.library_client import LlamaStackAsLibraryClient
@ -122,22 +120,11 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility) # Use the dedicated authorization parameter
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List runtime tools # List runtime tools
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
tools = response tools = response
@ -173,22 +160,11 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility) # Use the dedicated authorization parameter
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List tools # List tools
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Find book_flight tool (which should have $ref/$defs) # Find book_flight tool (which should have $ref/$defs)
@ -230,21 +206,10 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility) # Use the dedicated authorization parameter
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Find get_weather tool # Find get_weather tool
@ -284,22 +249,10 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility) # Use the dedicated authorization parameter
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools( llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Invoke tool with complex nested data # Invoke tool with complex nested data
@ -311,7 +264,7 @@ class TestMCPToolInvocation:
"shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}}, "shipping": {"address": {"street": "123 Main St", "city": "San Francisco", "zipcode": "94102"}},
} }
}, },
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Should succeed without schema validation errors # Should succeed without schema validation errors
@ -337,29 +290,17 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility) # Use the dedicated authorization parameter
provider_data = {
"mcp_headers": {
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
# List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools( llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
# Test with email format # Test with email format
result_email = llama_stack_client.tool_runtime.invoke_tool( result_email = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact", tool_name="flexible_contact",
kwargs={"contact_info": "user@example.com"}, kwargs={"contact_info": "user@example.com"},
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
assert result_email.error_message is None assert result_email.error_message is None
@ -368,7 +309,7 @@ class TestMCPToolInvocation:
result_phone = llama_stack_client.tool_runtime.invoke_tool( result_phone = llama_stack_client.tool_runtime.invoke_tool(
tool_name="flexible_contact", tool_name="flexible_contact",
kwargs={"contact_info": "+15551234567"}, kwargs={"contact_info": "+15551234567"},
extra_headers=auth_headers, authorization=AUTH_TOKEN,
) )
assert result_phone.error_message is None assert result_phone.error_message is None
@ -400,21 +341,10 @@ class TestAgentWithMCPTools:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Use old header-based approach for Phase 1 (backward compatibility) # Use the dedicated authorization parameter
provider_data = { tools_list = llama_stack_client.tool_runtime.list_tools(
"mcp_headers": { tool_group_id=test_toolgroup_id,
uri: { authorization=AUTH_TOKEN,
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers,
) )
tool_defs = [ tool_defs = [
{ {

4
uv.lock generated
View file

@ -2165,10 +2165,8 @@ requires-dist = [
{ name = "python-dotenv" }, { name = "python-dotenv" },
{ name = "python-multipart", specifier = ">=0.0.20" }, { name = "python-multipart", specifier = ">=0.0.20" },
{ name = "pyyaml", specifier = ">=6.0" }, { name = "pyyaml", specifier = ">=6.0" },
{ name = "pyyaml", specifier = ">=6.0.2" },
{ name = "rich" }, { name = "rich" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
{ name = "starlette" },
{ name = "starlette", specifier = ">=0.49.1" }, { name = "starlette", specifier = ">=0.49.1" },
{ name = "termcolor" }, { name = "termcolor" },
{ name = "tiktoken" }, { name = "tiktoken" },
@ -4656,6 +4654,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6b/fa/3234f913fe9a6525a7b97c6dad1f51e72b917e6872e051a5e2ffd8b16fbb/ruamel.yaml.clib-0.2.14-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:70eda7703b8126f5e52fcf276e6c0f40b0d314674f896fc58c47b0aef2b9ae83", size = 137970, upload-time = "2025-09-22T19:51:09.472Z" }, { url = "https://files.pythonhosted.org/packages/6b/fa/3234f913fe9a6525a7b97c6dad1f51e72b917e6872e051a5e2ffd8b16fbb/ruamel.yaml.clib-0.2.14-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:70eda7703b8126f5e52fcf276e6c0f40b0d314674f896fc58c47b0aef2b9ae83", size = 137970, upload-time = "2025-09-22T19:51:09.472Z" },
{ url = "https://files.pythonhosted.org/packages/ef/ec/4edbf17ac2c87fa0845dd366ef8d5852b96eb58fcd65fc1ecf5fe27b4641/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a0cb71ccc6ef9ce36eecb6272c81afdc2f565950cdcec33ae8e6cd8f7fc86f27", size = 739639, upload-time = "2025-09-22T19:51:10.566Z" }, { url = "https://files.pythonhosted.org/packages/ef/ec/4edbf17ac2c87fa0845dd366ef8d5852b96eb58fcd65fc1ecf5fe27b4641/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a0cb71ccc6ef9ce36eecb6272c81afdc2f565950cdcec33ae8e6cd8f7fc86f27", size = 739639, upload-time = "2025-09-22T19:51:10.566Z" },
{ url = "https://files.pythonhosted.org/packages/15/18/b0e1fafe59051de9e79cdd431863b03593ecfa8341c110affad7c8121efc/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e7cb9ad1d525d40f7d87b6df7c0ff916a66bc52cb61b66ac1b2a16d0c1b07640", size = 764456, upload-time = "2025-09-22T19:51:11.736Z" }, { url = "https://files.pythonhosted.org/packages/15/18/b0e1fafe59051de9e79cdd431863b03593ecfa8341c110affad7c8121efc/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e7cb9ad1d525d40f7d87b6df7c0ff916a66bc52cb61b66ac1b2a16d0c1b07640", size = 764456, upload-time = "2025-09-22T19:51:11.736Z" },
{ url = "https://files.pythonhosted.org/packages/e7/cd/150fdb96b8fab27fe08d8a59fe67554568727981806e6bc2677a16081ec7/ruamel_yaml_clib-0.2.14-cp314-cp314-win32.whl", hash = "sha256:9b4104bf43ca0cd4e6f738cb86326a3b2f6eef00f417bd1e7efb7bdffe74c539", size = 102394, upload-time = "2025-11-14T21:57:36.703Z" },
{ url = "https://files.pythonhosted.org/packages/bd/e6/a3fa40084558c7e1dc9546385f22a93949c890a8b2e445b2ba43935f51da/ruamel_yaml_clib-0.2.14-cp314-cp314-win_amd64.whl", hash = "sha256:13997d7d354a9890ea1ec5937a219817464e5cc344805b37671562a401ca3008", size = 122673, upload-time = "2025-11-14T21:57:38.177Z" },
] ]
[[package]] [[package]]