mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore: re-add x-llama-stack-extra-body-params
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
c4cad890cc
commit
66056ddb87
6 changed files with 842 additions and 339 deletions
|
|
@ -36,6 +36,10 @@ _dynamic_models = []
|
|||
# Cache for protocol methods to avoid repeated lookups
|
||||
_protocol_methods_cache: dict[Api, dict[str, Any]] | None = None
|
||||
|
||||
# Global dict to store extra body field information by endpoint
|
||||
# Key: (path, method) tuple, Value: list of (param_name, param_type, description) tuples
|
||||
_extra_body_fields: dict[tuple[str, str], list[tuple[str, type, str | None]]] = {}
|
||||
|
||||
|
||||
def create_llama_stack_app() -> FastAPI:
|
||||
"""
|
||||
|
|
@ -238,6 +242,15 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)
|
||||
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
|
||||
|
||||
# Retrieve and store extra body fields for this endpoint
|
||||
func = _get_protocol_method(api, name)
|
||||
extra_body_params = getattr(func, "_extra_body_params", []) if func else []
|
||||
if extra_body_params:
|
||||
global _extra_body_fields
|
||||
for method in methods:
|
||||
key = (fastapi_path, method.upper())
|
||||
_extra_body_fields[key] = extra_body_params
|
||||
|
||||
if file_form_params and is_post_put:
|
||||
signature_params = list(file_form_params)
|
||||
param_annotations = {param.name: param.annotation for param in file_form_params}
|
||||
|
|
@ -402,6 +415,13 @@ def _is_file_or_form_param(param_type: Any) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _is_extra_body_field(metadata_item: Any) -> bool:
|
||||
"""Check if a metadata item is an ExtraBodyField instance."""
|
||||
from llama_stack.schema_utils import ExtraBodyField
|
||||
|
||||
return isinstance(metadata_item, ExtraBodyField)
|
||||
|
||||
|
||||
def _find_models_for_endpoint(
|
||||
webmethod, api: Api, method_name: str
|
||||
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter]]:
|
||||
|
|
@ -433,6 +453,7 @@ def _find_models_for_endpoint(
|
|||
query_parameters = []
|
||||
file_form_params = []
|
||||
path_params = set()
|
||||
extra_body_params = []
|
||||
|
||||
# Extract path parameters from the route
|
||||
if webmethod and hasattr(webmethod, "route"):
|
||||
|
|
@ -462,6 +483,26 @@ def _find_models_for_endpoint(
|
|||
file_form_params.append(param)
|
||||
continue
|
||||
|
||||
# Check for ExtraBodyField in Annotated types
|
||||
is_extra_body = False
|
||||
extra_body_description = None
|
||||
if get_origin(param_type) is Annotated:
|
||||
args = get_args(param_type)
|
||||
base_type = args[0] if args else param_type
|
||||
metadata = args[1:] if len(args) > 1 else []
|
||||
|
||||
# Check if any metadata item is an ExtraBodyField
|
||||
for metadata_item in metadata:
|
||||
if _is_extra_body_field(metadata_item):
|
||||
is_extra_body = True
|
||||
extra_body_description = metadata_item.description
|
||||
break
|
||||
|
||||
if is_extra_body:
|
||||
# Store as extra body parameter - exclude from request model
|
||||
extra_body_params.append((param_name, base_type, extra_body_description))
|
||||
continue
|
||||
|
||||
# Check if it's a Pydantic model (for POST/PUT requests)
|
||||
if hasattr(param_type, "model_json_schema"):
|
||||
# Collect all body parameters including Pydantic models
|
||||
|
|
@ -486,6 +527,12 @@ def _find_models_for_endpoint(
|
|||
# Also make it safe for FastAPI to avoid forward reference issues
|
||||
query_parameters.append((param_name, param_type, default_value))
|
||||
|
||||
# Store extra body fields for later use in post-processing
|
||||
# We'll store them when the endpoint is created, as we need the full path
|
||||
# For now, attach to the function for later retrieval
|
||||
if extra_body_params:
|
||||
func._extra_body_params = extra_body_params # type: ignore
|
||||
|
||||
# If there's exactly one body parameter and it's a Pydantic model, use it directly
|
||||
# Otherwise, we'll create a combined request model from all parameters
|
||||
if len(query_parameters) == 1:
|
||||
|
|
@ -965,6 +1012,100 @@ def _clean_schema_descriptions(openapi_schema: dict[str, Any]) -> dict[str, Any]
|
|||
return openapi_schema
|
||||
|
||||
|
||||
def _add_extra_body_params_extension(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Add x-llama-stack-extra-body-params extension to requestBody for endpoints with ExtraBodyField parameters.
|
||||
"""
|
||||
if "paths" not in openapi_schema:
|
||||
return openapi_schema
|
||||
|
||||
global _extra_body_fields
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
for path, path_item in openapi_schema["paths"].items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
|
||||
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
if method not in path_item:
|
||||
continue
|
||||
|
||||
operation = path_item[method]
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
# Check if we have extra body fields for this path/method
|
||||
key = (path, method.upper())
|
||||
if key not in _extra_body_fields:
|
||||
continue
|
||||
|
||||
extra_body_params = _extra_body_fields[key]
|
||||
|
||||
# Ensure requestBody exists
|
||||
if "requestBody" not in operation:
|
||||
continue
|
||||
|
||||
request_body = operation["requestBody"]
|
||||
if not isinstance(request_body, dict):
|
||||
continue
|
||||
|
||||
# Get the schema from requestBody
|
||||
content = request_body.get("content", {})
|
||||
json_content = content.get("application/json", {})
|
||||
schema_ref = json_content.get("schema", {})
|
||||
|
||||
# Remove extra body fields from the schema if they exist as properties
|
||||
# Handle both $ref schemas and inline schemas
|
||||
if isinstance(schema_ref, dict):
|
||||
if "$ref" in schema_ref:
|
||||
# Schema is a reference - remove from the referenced schema
|
||||
ref_path = schema_ref["$ref"]
|
||||
if ref_path.startswith("#/components/schemas/"):
|
||||
schema_name = ref_path.split("/")[-1]
|
||||
if "components" in openapi_schema and "schemas" in openapi_schema["components"]:
|
||||
schema_def = openapi_schema["components"]["schemas"].get(schema_name)
|
||||
if isinstance(schema_def, dict) and "properties" in schema_def:
|
||||
for param_name, _, _ in extra_body_params:
|
||||
if param_name in schema_def["properties"]:
|
||||
del schema_def["properties"][param_name]
|
||||
# Also remove from required if present
|
||||
if "required" in schema_def and param_name in schema_def["required"]:
|
||||
schema_def["required"].remove(param_name)
|
||||
elif "properties" in schema_ref:
|
||||
# Schema is inline - remove directly from it
|
||||
for param_name, _, _ in extra_body_params:
|
||||
if param_name in schema_ref["properties"]:
|
||||
del schema_ref["properties"][param_name]
|
||||
# Also remove from required if present
|
||||
if "required" in schema_ref and param_name in schema_ref["required"]:
|
||||
schema_ref["required"].remove(param_name)
|
||||
|
||||
# Build the extra body params schema
|
||||
extra_params_schema = {}
|
||||
for param_name, param_type, description in extra_body_params:
|
||||
try:
|
||||
# Generate JSON schema for the parameter type
|
||||
adapter = TypeAdapter(param_type)
|
||||
param_schema = adapter.json_schema(ref_template="#/components/schemas/{model}")
|
||||
|
||||
# Add description if provided
|
||||
if description:
|
||||
param_schema["description"] = description
|
||||
|
||||
extra_params_schema[param_name] = param_schema
|
||||
except Exception:
|
||||
# If we can't generate schema, skip this parameter
|
||||
continue
|
||||
|
||||
if extra_params_schema:
|
||||
# Add the extension to requestBody
|
||||
if "x-llama-stack-extra-body-params" not in request_body:
|
||||
request_body["x-llama-stack-extra-body-params"] = extra_params_schema
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _remove_query_params_from_body_endpoints(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Remove query parameters from POST/PUT/PATCH endpoints that have a request body.
|
||||
|
|
@ -1399,6 +1540,9 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
|
|||
# FastAPI sometimes infers parameters as query params even when they should be in the request body
|
||||
openapi_schema = _remove_query_params_from_body_endpoints(openapi_schema)
|
||||
|
||||
# Add x-llama-stack-extra-body-params extension for ExtraBodyField parameters
|
||||
openapi_schema = _add_extra_body_params_extension(openapi_schema)
|
||||
|
||||
# Split into stable (v1 only), experimental (v1alpha + v1beta), deprecated, and combined (stainless) specs
|
||||
# Each spec needs its own deep copy of the full schema to avoid cross-contamination
|
||||
import copy
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue