mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
schema naming cleanup, should be much closer
This commit is contained in:
parent
69e1176ff8
commit
5293b4e5e9
11 changed files with 20459 additions and 12557 deletions
|
|
@ -12,16 +12,45 @@ import inspect
|
|||
import re
|
||||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from typing import Annotated, Any, get_args, get_origin
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import Field, create_model
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import Api
|
||||
|
||||
from . import app as app_module
|
||||
from .state import _dynamic_models, _extra_body_fields
|
||||
from .state import _extra_body_fields, register_dynamic_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def _to_pascal_case(segment: str) -> str:
|
||||
tokens = re.findall(r"[A-Za-z]+|\d+", segment)
|
||||
return "".join(token.capitalize() for token in tokens if token)
|
||||
|
||||
|
||||
def _compose_request_model_name(webmethod, http_method: str, variant: str | None = None) -> str:
|
||||
segments = []
|
||||
level = (webmethod.level or "").lower()
|
||||
if level and level != "v1":
|
||||
segments.append(_to_pascal_case(str(webmethod.level)))
|
||||
for part in filter(None, webmethod.route.split("/")):
|
||||
lower_part = part.lower()
|
||||
if lower_part in {"v1", "v1alpha", "v1beta"}:
|
||||
continue
|
||||
if part.startswith("{"):
|
||||
param = part[1:].split(":", 1)[0]
|
||||
segments.append(f"By{_to_pascal_case(param)}")
|
||||
else:
|
||||
segments.append(_to_pascal_case(part))
|
||||
if not segments:
|
||||
segments.append("Root")
|
||||
base_name = "".join(segments) + http_method.title() + "Request"
|
||||
if variant:
|
||||
base_name = f"{base_name}{variant}"
|
||||
return base_name
|
||||
|
||||
|
||||
def _extract_path_parameters(path: str) -> list[dict[str, Any]]:
|
||||
|
|
@ -99,21 +128,21 @@ def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_
|
|||
|
||||
|
||||
def _create_dynamic_request_model(
|
||||
webmethod, query_parameters: list[tuple[str, type, Any]], use_any: bool = False, add_uuid: bool = False
|
||||
api: Api,
|
||||
webmethod,
|
||||
http_method: str,
|
||||
query_parameters: list[tuple[str, type, Any]],
|
||||
use_any: bool = False,
|
||||
variant_suffix: str | None = None,
|
||||
) -> type | None:
|
||||
"""Create a dynamic Pydantic model for request body."""
|
||||
try:
|
||||
field_definitions = _build_field_definitions(query_parameters, use_any)
|
||||
if not field_definitions:
|
||||
return None
|
||||
clean_route = webmethod.route.replace("/", "_").replace("{", "").replace("}", "").replace("-", "_")
|
||||
model_name = f"{clean_route}_Request"
|
||||
if add_uuid:
|
||||
model_name = f"{model_name}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
model_name = _compose_request_model_name(webmethod, http_method, variant_suffix)
|
||||
request_model = create_model(model_name, **field_definitions)
|
||||
_dynamic_models.append(request_model)
|
||||
return request_model
|
||||
return register_dynamic_model(model_name, request_model)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
|
@ -190,7 +219,7 @@ def _is_file_or_form_param(param_type: Any) -> bool:
|
|||
|
||||
def _is_extra_body_field(metadata_item: Any) -> bool:
|
||||
"""Check if a metadata item is an ExtraBodyField instance."""
|
||||
from llama_stack.schema_utils import ExtraBodyField
|
||||
from llama_stack_api.schema_utils import ExtraBodyField
|
||||
|
||||
return isinstance(metadata_item, ExtraBodyField)
|
||||
|
||||
|
|
@ -232,7 +261,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
|
|||
streaming_model = inner_type
|
||||
else:
|
||||
# Might be a registered schema - check if it's registered
|
||||
from llama_stack.schema_utils import _registered_schemas
|
||||
from llama_stack_api.schema_utils import _registered_schemas
|
||||
|
||||
if inner_type in _registered_schemas:
|
||||
# We'll need to look this up later, but for now store the type
|
||||
|
|
@ -247,7 +276,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
|
|||
|
||||
def _find_models_for_endpoint(
|
||||
webmethod, api: Api, method_name: str, is_post_put: bool = False
|
||||
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None]:
|
||||
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None, str | None]:
|
||||
"""
|
||||
Find appropriate request and response models for an endpoint by analyzing the actual function signature.
|
||||
This uses the protocol function to determine the correct models dynamically.
|
||||
|
|
@ -259,16 +288,18 @@ def _find_models_for_endpoint(
|
|||
is_post_put: Whether this is a POST, PUT, or PATCH request (GET requests should never have request bodies)
|
||||
|
||||
Returns:
|
||||
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model)
|
||||
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name)
|
||||
where query_parameters is a list of (name, type, default_value) tuples
|
||||
and file_form_params is a list of inspect.Parameter objects for File()/Form() params
|
||||
and streaming_response_model is the model for streaming responses (AsyncIterator content)
|
||||
"""
|
||||
route_descriptor = f"{webmethod.method or 'UNKNOWN'} {webmethod.route}"
|
||||
try:
|
||||
# Get the function from the protocol
|
||||
func = app_module._get_protocol_method(api, method_name)
|
||||
if not func:
|
||||
return None, None, [], [], None
|
||||
logger.warning("No protocol method for %s.%s (%s)", api, method_name, route_descriptor)
|
||||
return None, None, [], [], None, None
|
||||
|
||||
# Analyze the function signature
|
||||
sig = inspect.signature(func)
|
||||
|
|
@ -279,6 +310,7 @@ def _find_models_for_endpoint(
|
|||
file_form_params = []
|
||||
path_params = set()
|
||||
extra_body_params = []
|
||||
response_schema_name = None
|
||||
|
||||
# Extract path parameters from the route
|
||||
if webmethod and hasattr(webmethod, "route"):
|
||||
|
|
@ -391,23 +423,49 @@ def _find_models_for_endpoint(
|
|||
elif origin is not None and (origin is types.UnionType or origin is typing.Union):
|
||||
# Handle union types - extract both non-streaming and streaming models
|
||||
response_model, streaming_response_model = _extract_response_models_from_union(return_annotation)
|
||||
else:
|
||||
try:
|
||||
from fastapi import Response as FastAPIResponse
|
||||
except ImportError:
|
||||
FastAPIResponse = None
|
||||
try:
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
except ImportError:
|
||||
StarletteResponse = None
|
||||
|
||||
return request_model, response_model, query_parameters, file_form_params, streaming_response_model
|
||||
response_types = tuple(t for t in (FastAPIResponse, StarletteResponse) if t is not None)
|
||||
if response_types and any(return_annotation is t for t in response_types):
|
||||
response_schema_name = "Response"
|
||||
|
||||
except Exception:
|
||||
# If we can't analyze the function signature, return None
|
||||
return None, None, [], [], None
|
||||
return request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to analyze endpoint %s.%s (%s): %s", api, method_name, route_descriptor, exc, exc_info=True
|
||||
)
|
||||
return None, None, [], [], None, None
|
||||
|
||||
|
||||
def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
||||
"""Create a FastAPI endpoint from a discovered route and webmethod."""
|
||||
path = route.path
|
||||
methods = route.methods
|
||||
raw_methods = route.methods or set()
|
||||
method_list = sorted({method.upper() for method in raw_methods if method and method.upper() != "HEAD"})
|
||||
if not method_list:
|
||||
method_list = ["GET"]
|
||||
primary_method = method_list[0]
|
||||
name = route.name
|
||||
fastapi_path = path.replace("{", "{").replace("}", "}")
|
||||
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
|
||||
is_post_put = any(method in ["POST", "PUT", "PATCH"] for method in method_list)
|
||||
|
||||
request_model, response_model, query_parameters, file_form_params, streaming_response_model = (
|
||||
(
|
||||
request_model,
|
||||
response_model,
|
||||
query_parameters,
|
||||
file_form_params,
|
||||
streaming_response_model,
|
||||
response_schema_name,
|
||||
) = (
|
||||
_find_models_for_endpoint(webmethod, api, name, is_post_put)
|
||||
)
|
||||
operation_description = _extract_operation_description_from_docstring(api, name)
|
||||
|
|
@ -417,7 +475,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
func = app_module._get_protocol_method(api, name)
|
||||
extra_body_params = getattr(func, "_extra_body_params", []) if func else []
|
||||
if extra_body_params:
|
||||
for method in methods:
|
||||
for method in method_list:
|
||||
key = (fastapi_path, method.upper())
|
||||
_extra_body_fields[key] = extra_body_params
|
||||
|
||||
|
|
@ -447,12 +505,11 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description)
|
||||
elif response_model and query_parameters:
|
||||
if is_post_put:
|
||||
# Try creating request model with type preservation, fallback to Any, then minimal
|
||||
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=False)
|
||||
request_model = _create_dynamic_request_model(api, webmethod, primary_method, query_parameters, use_any=False)
|
||||
if not request_model:
|
||||
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True)
|
||||
if not request_model:
|
||||
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True, add_uuid=True)
|
||||
request_model = _create_dynamic_request_model(
|
||||
api, webmethod, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
|
||||
)
|
||||
|
||||
if request_model:
|
||||
endpoint_func = _create_endpoint_with_request_model(
|
||||
|
|
@ -532,16 +589,18 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
endpoint_func = no_params_endpoint
|
||||
|
||||
# Build response content with both application/json and text/event-stream if streaming
|
||||
response_content = {}
|
||||
response_content: dict[str, Any] = {}
|
||||
if response_model:
|
||||
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}}
|
||||
elif response_schema_name:
|
||||
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_schema_name}"}}
|
||||
if streaming_response_model:
|
||||
# Get the schema name for the streaming model
|
||||
# It might be a registered schema or a Pydantic model
|
||||
streaming_schema_name = None
|
||||
# Check if it's a registered schema first (before checking __name__)
|
||||
# because registered schemas might be Annotated types
|
||||
from llama_stack.schema_utils import _registered_schemas
|
||||
from llama_stack_api.schema_utils import _registered_schemas
|
||||
|
||||
if streaming_response_model in _registered_schemas:
|
||||
streaming_schema_name = _registered_schemas[streaming_response_model]["name"]
|
||||
|
|
@ -554,9 +613,6 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
}
|
||||
|
||||
# If no content types, use empty schema
|
||||
if not response_content:
|
||||
response_content["application/json"] = {"schema": {}}
|
||||
|
||||
# Add the endpoint to the FastAPI app
|
||||
is_deprecated = webmethod.deprecated or False
|
||||
route_kwargs = {
|
||||
|
|
@ -564,16 +620,16 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
"tags": [_get_tag_from_api(api)],
|
||||
"deprecated": is_deprecated,
|
||||
"responses": {
|
||||
200: {
|
||||
"description": response_description,
|
||||
"content": response_content,
|
||||
},
|
||||
400: {"$ref": "#/components/responses/BadRequest400"},
|
||||
429: {"$ref": "#/components/responses/TooManyRequests429"},
|
||||
500: {"$ref": "#/components/responses/InternalServerError500"},
|
||||
"default": {"$ref": "#/components/responses/DefaultError"},
|
||||
},
|
||||
}
|
||||
success_response: dict[str, Any] = {"description": response_description}
|
||||
if response_content:
|
||||
success_response["content"] = response_content
|
||||
route_kwargs["responses"][200] = success_response
|
||||
|
||||
# FastAPI needs response_model parameter to properly generate OpenAPI spec
|
||||
# Use the non-streaming response model if available
|
||||
|
|
@ -581,6 +637,6 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
route_kwargs["response_model"] = response_model
|
||||
|
||||
method_map = {"GET": app.get, "POST": app.post, "PUT": app.put, "DELETE": app.delete, "PATCH": app.patch}
|
||||
for method in methods:
|
||||
if handler := method_map.get(method.upper()):
|
||||
for method in method_list:
|
||||
if handler := method_map.get(method):
|
||||
handler(fastapi_path, **route_kwargs)(endpoint_func)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue