mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 10:10:36 +00:00
even more cleanup, the deltas should be much smaller now
This commit is contained in:
parent
5293b4e5e9
commit
9deb0beb86
14 changed files with 5038 additions and 17435 deletions
|
|
@ -19,6 +19,7 @@ from pydantic import Field, create_model
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import Api
|
||||
from llama_stack_api.schema_utils import get_registered_schema_info
|
||||
|
||||
from . import app as app_module
|
||||
from .state import _extra_body_fields, register_dynamic_model
|
||||
|
|
@ -31,23 +32,16 @@ def _to_pascal_case(segment: str) -> str:
|
|||
return "".join(token.capitalize() for token in tokens if token)
|
||||
|
||||
|
||||
def _compose_request_model_name(webmethod, http_method: str, variant: str | None = None) -> str:
|
||||
segments = []
|
||||
level = (webmethod.level or "").lower()
|
||||
if level and level != "v1":
|
||||
segments.append(_to_pascal_case(str(webmethod.level)))
|
||||
for part in filter(None, webmethod.route.split("/")):
|
||||
lower_part = part.lower()
|
||||
if lower_part in {"v1", "v1alpha", "v1beta"}:
|
||||
continue
|
||||
if part.startswith("{"):
|
||||
param = part[1:].split(":", 1)[0]
|
||||
segments.append(f"By{_to_pascal_case(param)}")
|
||||
else:
|
||||
segments.append(_to_pascal_case(part))
|
||||
if not segments:
|
||||
segments.append("Root")
|
||||
base_name = "".join(segments) + http_method.title() + "Request"
|
||||
def _compose_request_model_name(api: Api, method_name: str, variant: str | None = None) -> str:
|
||||
"""Generate a deterministic model name from the protocol method."""
|
||||
|
||||
def _to_pascal_from_snake(value: str) -> str:
|
||||
return "".join(segment.capitalize() for segment in value.split("_") if segment)
|
||||
|
||||
base_name = _to_pascal_from_snake(method_name)
|
||||
if not base_name:
|
||||
base_name = _to_pascal_case(api.value)
|
||||
base_name = f"{base_name}Request"
|
||||
if variant:
|
||||
base_name = f"{base_name}{variant}"
|
||||
return base_name
|
||||
|
|
@ -130,6 +124,7 @@ def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_
|
|||
def _create_dynamic_request_model(
|
||||
api: Api,
|
||||
webmethod,
|
||||
method_name: str,
|
||||
http_method: str,
|
||||
query_parameters: list[tuple[str, type, Any]],
|
||||
use_any: bool = False,
|
||||
|
|
@ -140,7 +135,7 @@ def _create_dynamic_request_model(
|
|||
field_definitions = _build_field_definitions(query_parameters, use_any)
|
||||
if not field_definitions:
|
||||
return None
|
||||
model_name = _compose_request_model_name(webmethod, http_method, variant_suffix)
|
||||
model_name = _compose_request_model_name(api, method_name, variant_suffix or None)
|
||||
request_model = create_model(model_name, **field_definitions)
|
||||
return register_dynamic_model(model_name, request_model)
|
||||
except Exception:
|
||||
|
|
@ -261,9 +256,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
|
|||
streaming_model = inner_type
|
||||
else:
|
||||
# Might be a registered schema - check if it's registered
|
||||
from llama_stack_api.schema_utils import _registered_schemas
|
||||
|
||||
if inner_type in _registered_schemas:
|
||||
if get_registered_schema_info(inner_type):
|
||||
# We'll need to look this up later, but for now store the type
|
||||
streaming_model = inner_type
|
||||
elif hasattr(arg, "model_json_schema"):
|
||||
|
|
@ -427,17 +420,28 @@ def _find_models_for_endpoint(
|
|||
try:
|
||||
from fastapi import Response as FastAPIResponse
|
||||
except ImportError:
|
||||
FastAPIResponse = None
|
||||
fastapi_response_cls = None
|
||||
else:
|
||||
fastapi_response_cls = FastAPIResponse
|
||||
try:
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
except ImportError:
|
||||
StarletteResponse = None
|
||||
starlette_response_cls = None
|
||||
else:
|
||||
starlette_response_cls = StarletteResponse
|
||||
|
||||
response_types = tuple(t for t in (FastAPIResponse, StarletteResponse) if t is not None)
|
||||
response_types = tuple(t for t in (fastapi_response_cls, starlette_response_cls) if t is not None)
|
||||
if response_types and any(return_annotation is t for t in response_types):
|
||||
response_schema_name = "Response"
|
||||
|
||||
return request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name
|
||||
return (
|
||||
request_model,
|
||||
response_model,
|
||||
query_parameters,
|
||||
file_form_params,
|
||||
streaming_response_model,
|
||||
response_schema_name,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
|
|
@ -465,9 +469,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
file_form_params,
|
||||
streaming_response_model,
|
||||
response_schema_name,
|
||||
) = (
|
||||
_find_models_for_endpoint(webmethod, api, name, is_post_put)
|
||||
)
|
||||
) = _find_models_for_endpoint(webmethod, api, name, is_post_put)
|
||||
operation_description = _extract_operation_description_from_docstring(api, name)
|
||||
response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)
|
||||
|
||||
|
|
@ -479,6 +481,17 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
key = (fastapi_path, method.upper())
|
||||
_extra_body_fields[key] = extra_body_params
|
||||
|
||||
if is_post_put and not request_model and not file_form_params and query_parameters:
|
||||
request_model = _create_dynamic_request_model(
|
||||
api, webmethod, name, primary_method, query_parameters, use_any=False
|
||||
)
|
||||
if not request_model:
|
||||
request_model = _create_dynamic_request_model(
|
||||
api, webmethod, name, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
|
||||
)
|
||||
if request_model:
|
||||
query_parameters = []
|
||||
|
||||
if file_form_params and is_post_put:
|
||||
signature_params = list(file_form_params)
|
||||
param_annotations = {param.name: param.annotation for param in file_form_params}
|
||||
|
|
@ -503,12 +516,16 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
endpoint_func = file_form_endpoint
|
||||
elif request_model and response_model:
|
||||
endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description)
|
||||
elif request_model:
|
||||
endpoint_func = _create_endpoint_with_request_model(request_model, None, operation_description)
|
||||
elif response_model and query_parameters:
|
||||
if is_post_put:
|
||||
request_model = _create_dynamic_request_model(api, webmethod, primary_method, query_parameters, use_any=False)
|
||||
request_model = _create_dynamic_request_model(
|
||||
api, webmethod, name, primary_method, query_parameters, use_any=False
|
||||
)
|
||||
if not request_model:
|
||||
request_model = _create_dynamic_request_model(
|
||||
api, webmethod, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
|
||||
api, webmethod, name, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
|
||||
)
|
||||
|
||||
if request_model:
|
||||
|
|
@ -600,10 +617,8 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
streaming_schema_name = None
|
||||
# Check if it's a registered schema first (before checking __name__)
|
||||
# because registered schemas might be Annotated types
|
||||
from llama_stack_api.schema_utils import _registered_schemas
|
||||
|
||||
if streaming_response_model in _registered_schemas:
|
||||
streaming_schema_name = _registered_schemas[streaming_response_model]["name"]
|
||||
if schema_info := get_registered_schema_info(streaming_response_model):
|
||||
streaming_schema_name = schema_info.name
|
||||
elif hasattr(streaming_response_model, "__name__"):
|
||||
streaming_schema_name = streaming_response_model.__name__
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue