schema naming cleanup, should be much closer

This commit is contained in:
Ashwin Bharambe 2025-11-14 13:07:34 -08:00
parent 69e1176ff8
commit 5293b4e5e9
11 changed files with 20459 additions and 12557 deletions

View file

@ -12,16 +12,45 @@ import inspect
import re
import types
import typing
import uuid
from typing import Annotated, Any, get_args, get_origin
from fastapi import FastAPI
from pydantic import Field, create_model
from llama_stack.log import get_logger
from llama_stack_api import Api
from . import app as app_module
from .state import _dynamic_models, _extra_body_fields
from .state import _extra_body_fields, register_dynamic_model
logger = get_logger(name=__name__, category="core")
def _to_pascal_case(segment: str) -> str:
tokens = re.findall(r"[A-Za-z]+|\d+", segment)
return "".join(token.capitalize() for token in tokens if token)
def _compose_request_model_name(webmethod, http_method: str, variant: str | None = None) -> str:
segments = []
level = (webmethod.level or "").lower()
if level and level != "v1":
segments.append(_to_pascal_case(str(webmethod.level)))
for part in filter(None, webmethod.route.split("/")):
lower_part = part.lower()
if lower_part in {"v1", "v1alpha", "v1beta"}:
continue
if part.startswith("{"):
param = part[1:].split(":", 1)[0]
segments.append(f"By{_to_pascal_case(param)}")
else:
segments.append(_to_pascal_case(part))
if not segments:
segments.append("Root")
base_name = "".join(segments) + http_method.title() + "Request"
if variant:
base_name = f"{base_name}{variant}"
return base_name
def _extract_path_parameters(path: str) -> list[dict[str, Any]]:
@ -99,21 +128,21 @@ def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_
def _create_dynamic_request_model(
webmethod, query_parameters: list[tuple[str, type, Any]], use_any: bool = False, add_uuid: bool = False
api: Api,
webmethod,
http_method: str,
query_parameters: list[tuple[str, type, Any]],
use_any: bool = False,
variant_suffix: str | None = None,
) -> type | None:
"""Create a dynamic Pydantic model for request body."""
try:
field_definitions = _build_field_definitions(query_parameters, use_any)
if not field_definitions:
return None
clean_route = webmethod.route.replace("/", "_").replace("{", "").replace("}", "").replace("-", "_")
model_name = f"{clean_route}_Request"
if add_uuid:
model_name = f"{model_name}_{uuid.uuid4().hex[:8]}"
model_name = _compose_request_model_name(webmethod, http_method, variant_suffix)
request_model = create_model(model_name, **field_definitions)
_dynamic_models.append(request_model)
return request_model
return register_dynamic_model(model_name, request_model)
except Exception:
return None
@ -190,7 +219,7 @@ def _is_file_or_form_param(param_type: Any) -> bool:
def _is_extra_body_field(metadata_item: Any) -> bool:
"""Check if a metadata item is an ExtraBodyField instance."""
from llama_stack.schema_utils import ExtraBodyField
from llama_stack_api.schema_utils import ExtraBodyField
return isinstance(metadata_item, ExtraBodyField)
@ -232,7 +261,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
streaming_model = inner_type
else:
# Might be a registered schema - check if it's registered
from llama_stack.schema_utils import _registered_schemas
from llama_stack_api.schema_utils import _registered_schemas
if inner_type in _registered_schemas:
# We'll need to look this up later, but for now store the type
@ -247,7 +276,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
def _find_models_for_endpoint(
webmethod, api: Api, method_name: str, is_post_put: bool = False
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None]:
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None, str | None]:
"""
Find appropriate request and response models for an endpoint by analyzing the actual function signature.
This uses the protocol function to determine the correct models dynamically.
@ -259,16 +288,18 @@ def _find_models_for_endpoint(
is_post_put: Whether this is a POST, PUT, or PATCH request (GET requests should never have request bodies)
Returns:
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model)
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name)
where query_parameters is a list of (name, type, default_value) tuples
and file_form_params is a list of inspect.Parameter objects for File()/Form() params
and streaming_response_model is the model for streaming responses (AsyncIterator content)
"""
route_descriptor = f"{webmethod.method or 'UNKNOWN'} {webmethod.route}"
try:
# Get the function from the protocol
func = app_module._get_protocol_method(api, method_name)
if not func:
return None, None, [], [], None
logger.warning("No protocol method for %s.%s (%s)", api, method_name, route_descriptor)
return None, None, [], [], None, None
# Analyze the function signature
sig = inspect.signature(func)
@ -279,6 +310,7 @@ def _find_models_for_endpoint(
file_form_params = []
path_params = set()
extra_body_params = []
response_schema_name = None
# Extract path parameters from the route
if webmethod and hasattr(webmethod, "route"):
@ -391,23 +423,49 @@ def _find_models_for_endpoint(
elif origin is not None and (origin is types.UnionType or origin is typing.Union):
# Handle union types - extract both non-streaming and streaming models
response_model, streaming_response_model = _extract_response_models_from_union(return_annotation)
else:
try:
from fastapi import Response as FastAPIResponse
except ImportError:
FastAPIResponse = None
try:
from starlette.responses import Response as StarletteResponse
except ImportError:
StarletteResponse = None
return request_model, response_model, query_parameters, file_form_params, streaming_response_model
response_types = tuple(t for t in (FastAPIResponse, StarletteResponse) if t is not None)
if response_types and any(return_annotation is t for t in response_types):
response_schema_name = "Response"
except Exception:
# If we can't analyze the function signature, return None
return None, None, [], [], None
return request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name
except Exception as exc:
logger.warning(
"Failed to analyze endpoint %s.%s (%s): %s", api, method_name, route_descriptor, exc, exc_info=True
)
return None, None, [], [], None, None
def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
"""Create a FastAPI endpoint from a discovered route and webmethod."""
path = route.path
methods = route.methods
raw_methods = route.methods or set()
method_list = sorted({method.upper() for method in raw_methods if method and method.upper() != "HEAD"})
if not method_list:
method_list = ["GET"]
primary_method = method_list[0]
name = route.name
fastapi_path = path.replace("{", "{").replace("}", "}")
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
is_post_put = any(method in ["POST", "PUT", "PATCH"] for method in method_list)
request_model, response_model, query_parameters, file_form_params, streaming_response_model = (
(
request_model,
response_model,
query_parameters,
file_form_params,
streaming_response_model,
response_schema_name,
) = (
_find_models_for_endpoint(webmethod, api, name, is_post_put)
)
operation_description = _extract_operation_description_from_docstring(api, name)
@ -417,7 +475,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
func = app_module._get_protocol_method(api, name)
extra_body_params = getattr(func, "_extra_body_params", []) if func else []
if extra_body_params:
for method in methods:
for method in method_list:
key = (fastapi_path, method.upper())
_extra_body_fields[key] = extra_body_params
@ -447,12 +505,11 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description)
elif response_model and query_parameters:
if is_post_put:
# Try creating request model with type preservation, fallback to Any, then minimal
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=False)
request_model = _create_dynamic_request_model(api, webmethod, primary_method, query_parameters, use_any=False)
if not request_model:
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True)
if not request_model:
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True, add_uuid=True)
request_model = _create_dynamic_request_model(
api, webmethod, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
)
if request_model:
endpoint_func = _create_endpoint_with_request_model(
@ -532,16 +589,18 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
endpoint_func = no_params_endpoint
# Build response content with both application/json and text/event-stream if streaming
response_content = {}
response_content: dict[str, Any] = {}
if response_model:
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}}
elif response_schema_name:
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_schema_name}"}}
if streaming_response_model:
# Get the schema name for the streaming model
# It might be a registered schema or a Pydantic model
streaming_schema_name = None
# Check if it's a registered schema first (before checking __name__)
# because registered schemas might be Annotated types
from llama_stack.schema_utils import _registered_schemas
from llama_stack_api.schema_utils import _registered_schemas
if streaming_response_model in _registered_schemas:
streaming_schema_name = _registered_schemas[streaming_response_model]["name"]
@ -554,9 +613,6 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
}
# If no content types, use empty schema
if not response_content:
response_content["application/json"] = {"schema": {}}
# Add the endpoint to the FastAPI app
is_deprecated = webmethod.deprecated or False
route_kwargs = {
@ -564,16 +620,16 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
"tags": [_get_tag_from_api(api)],
"deprecated": is_deprecated,
"responses": {
200: {
"description": response_description,
"content": response_content,
},
400: {"$ref": "#/components/responses/BadRequest400"},
429: {"$ref": "#/components/responses/TooManyRequests429"},
500: {"$ref": "#/components/responses/InternalServerError500"},
"default": {"$ref": "#/components/responses/DefaultError"},
},
}
success_response: dict[str, Any] = {"description": response_description}
if response_content:
success_response["content"] = response_content
route_kwargs["responses"][200] = success_response
# FastAPI needs response_model parameter to properly generate OpenAPI spec
# Use the non-streaming response model if available
@ -581,6 +637,6 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
route_kwargs["response_model"] = response_model
method_map = {"GET": app.get, "POST": app.post, "PUT": app.put, "DELETE": app.delete, "PATCH": app.patch}
for method in methods:
if handler := method_map.get(method.upper()):
for method in method_list:
if handler := method_map.get(method):
handler(fastapi_path, **route_kwargs)(endpoint_func)