even more cleanup, the deltas should be much smaller now

This commit is contained in:
Ashwin Bharambe 2025-11-14 14:18:15 -08:00
parent 5293b4e5e9
commit 9deb0beb86
14 changed files with 5038 additions and 17435 deletions

View file

@ -19,6 +19,7 @@ from pydantic import Field, create_model
from llama_stack.log import get_logger
from llama_stack_api import Api
from llama_stack_api.schema_utils import get_registered_schema_info
from . import app as app_module
from .state import _extra_body_fields, register_dynamic_model
@ -31,23 +32,16 @@ def _to_pascal_case(segment: str) -> str:
return "".join(token.capitalize() for token in tokens if token)
def _compose_request_model_name(webmethod, http_method: str, variant: str | None = None) -> str:
segments = []
level = (webmethod.level or "").lower()
if level and level != "v1":
segments.append(_to_pascal_case(str(webmethod.level)))
for part in filter(None, webmethod.route.split("/")):
lower_part = part.lower()
if lower_part in {"v1", "v1alpha", "v1beta"}:
continue
if part.startswith("{"):
param = part[1:].split(":", 1)[0]
segments.append(f"By{_to_pascal_case(param)}")
else:
segments.append(_to_pascal_case(part))
if not segments:
segments.append("Root")
base_name = "".join(segments) + http_method.title() + "Request"
def _compose_request_model_name(api: Api, method_name: str, variant: str | None = None) -> str:
"""Generate a deterministic model name from the protocol method."""
def _to_pascal_from_snake(value: str) -> str:
return "".join(segment.capitalize() for segment in value.split("_") if segment)
base_name = _to_pascal_from_snake(method_name)
if not base_name:
base_name = _to_pascal_case(api.value)
base_name = f"{base_name}Request"
if variant:
base_name = f"{base_name}{variant}"
return base_name
@ -130,6 +124,7 @@ def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_
def _create_dynamic_request_model(
api: Api,
webmethod,
method_name: str,
http_method: str,
query_parameters: list[tuple[str, type, Any]],
use_any: bool = False,
@ -140,7 +135,7 @@ def _create_dynamic_request_model(
field_definitions = _build_field_definitions(query_parameters, use_any)
if not field_definitions:
return None
model_name = _compose_request_model_name(webmethod, http_method, variant_suffix)
model_name = _compose_request_model_name(api, method_name, variant_suffix or None)
request_model = create_model(model_name, **field_definitions)
return register_dynamic_model(model_name, request_model)
except Exception:
@ -261,9 +256,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
streaming_model = inner_type
else:
# Might be a registered schema - check if it's registered
from llama_stack_api.schema_utils import _registered_schemas
if inner_type in _registered_schemas:
if get_registered_schema_info(inner_type):
# We'll need to look this up later, but for now store the type
streaming_model = inner_type
elif hasattr(arg, "model_json_schema"):
@ -427,17 +420,28 @@ def _find_models_for_endpoint(
try:
from fastapi import Response as FastAPIResponse
except ImportError:
FastAPIResponse = None
fastapi_response_cls = None
else:
fastapi_response_cls = FastAPIResponse
try:
from starlette.responses import Response as StarletteResponse
except ImportError:
StarletteResponse = None
starlette_response_cls = None
else:
starlette_response_cls = StarletteResponse
response_types = tuple(t for t in (FastAPIResponse, StarletteResponse) if t is not None)
response_types = tuple(t for t in (fastapi_response_cls, starlette_response_cls) if t is not None)
if response_types and any(return_annotation is t for t in response_types):
response_schema_name = "Response"
return request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name
return (
request_model,
response_model,
query_parameters,
file_form_params,
streaming_response_model,
response_schema_name,
)
except Exception as exc:
logger.warning(
@ -465,9 +469,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
file_form_params,
streaming_response_model,
response_schema_name,
) = (
_find_models_for_endpoint(webmethod, api, name, is_post_put)
)
) = _find_models_for_endpoint(webmethod, api, name, is_post_put)
operation_description = _extract_operation_description_from_docstring(api, name)
response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)
@ -479,6 +481,17 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
key = (fastapi_path, method.upper())
_extra_body_fields[key] = extra_body_params
if is_post_put and not request_model and not file_form_params and query_parameters:
request_model = _create_dynamic_request_model(
api, webmethod, name, primary_method, query_parameters, use_any=False
)
if not request_model:
request_model = _create_dynamic_request_model(
api, webmethod, name, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
)
if request_model:
query_parameters = []
if file_form_params and is_post_put:
signature_params = list(file_form_params)
param_annotations = {param.name: param.annotation for param in file_form_params}
@ -503,12 +516,16 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
endpoint_func = file_form_endpoint
elif request_model and response_model:
endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description)
elif request_model:
endpoint_func = _create_endpoint_with_request_model(request_model, None, operation_description)
elif response_model and query_parameters:
if is_post_put:
request_model = _create_dynamic_request_model(api, webmethod, primary_method, query_parameters, use_any=False)
request_model = _create_dynamic_request_model(
api, webmethod, name, primary_method, query_parameters, use_any=False
)
if not request_model:
request_model = _create_dynamic_request_model(
api, webmethod, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
api, webmethod, name, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
)
if request_model:
@ -600,10 +617,8 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
streaming_schema_name = None
# Check if it's a registered schema first (before checking __name__)
# because registered schemas might be Annotated types
from llama_stack_api.schema_utils import _registered_schemas
if streaming_response_model in _registered_schemas:
streaming_schema_name = _registered_schemas[streaming_response_model]["name"]
if schema_info := get_registered_schema_info(streaming_response_model):
streaming_schema_name = schema_info.name
elif hasattr(streaming_response_model, "__name__"):
streaming_schema_name = streaming_response_model.__name__