schema naming cleanup, should be much closer

This commit is contained in:
Ashwin Bharambe 2025-11-14 13:07:34 -08:00
parent 69e1176ff8
commit 5293b4e5e9
11 changed files with 20459 additions and 12557 deletions

View file

@ -12,7 +12,6 @@ These lists help the new generator match the previous ordering so that diffs
remain readable while we debug schema content regressions. Remove once stable.
"""
# TODO: remove once generator output stabilizes
LEGACY_PATH_ORDER = ['/v1/batches',
'/v1/batches/{batch_id}',
'/v1/batches/{batch_id}/cancel',
@ -364,6 +363,57 @@ LEGACY_SCHEMA_ORDER = ['Error',
LEGACY_RESPONSE_ORDER = ['BadRequest400', 'TooManyRequests429', 'InternalServerError500', 'DefaultError']
LEGACY_TAGS = [{'description': 'APIs for creating and interacting with agentic systems.',
'name': 'Agents',
'x-displayName': 'Agents'},
{'description': 'The API is designed to allow use of openai client libraries for seamless integration.\n'
'\n'
'This API provides the following extensions:\n'
' - idempotent batch creation\n'
'\n'
'Note: This API is currently under active development and may undergo changes.',
'name': 'Batches',
'x-displayName': 'The Batches API enables efficient processing of multiple requests in a single operation, '
'particularly useful for processing large datasets, batch evaluation workflows, and cost-effective '
'inference at scale.'},
{'description': '', 'name': 'Benchmarks'},
{'description': 'Protocol for conversation management operations.',
'name': 'Conversations',
'x-displayName': 'Conversations'},
{'description': '', 'name': 'DatasetIO'},
{'description': '', 'name': 'Datasets'},
{'description': 'Llama Stack Evaluation API for running evaluations on model and agent candidates.',
'name': 'Eval',
'x-displayName': 'Evaluations'},
{'description': 'This API is used to upload documents that can be used with other Llama Stack APIs.',
'name': 'Files',
'x-displayName': 'Files'},
{'description': 'Llama Stack Inference API for generating completions, chat completions, and embeddings.\n'
'\n'
'This API provides the raw interface to the underlying models. Three kinds of models are supported:\n'
'- LLM models: these models generate "raw" and "chat" (conversational) completions.\n'
'- Embedding models: these models generate embeddings to be used for semantic search.\n'
'- Rerank models: these models reorder the documents based on their relevance to a query.',
'name': 'Inference',
'x-displayName': 'Inference'},
{'description': 'APIs for inspecting the Llama Stack service, including health status, available API routes with '
'methods and implementing providers.',
'name': 'Inspect',
'x-displayName': 'Inspect'},
{'description': '', 'name': 'Models'},
{'description': '', 'name': 'PostTraining (Coming Soon)'},
{'description': 'Protocol for prompt management operations.', 'name': 'Prompts', 'x-displayName': 'Prompts'},
{'description': 'Providers API for inspecting, listing, and modifying providers and their configurations.',
'name': 'Providers',
'x-displayName': 'Providers'},
{'description': 'OpenAI-compatible Moderations API.', 'name': 'Safety', 'x-displayName': 'Safety'},
{'description': '', 'name': 'Scoring'},
{'description': '', 'name': 'ScoringFunctions'},
{'description': '', 'name': 'Shields'},
{'description': '', 'name': 'ToolGroups'},
{'description': '', 'name': 'ToolRuntime'},
{'description': '', 'name': 'VectorIO'}]
LEGACY_TAG_ORDER = ['Agents',
'Batches',
'Benchmarks',
@ -408,3 +458,16 @@ LEGACY_TAG_GROUPS = [{'name': 'Operations',
'ToolGroups',
'ToolRuntime',
'VectorIO']}]
# Top-level `security` value copied verbatim from the legacy generator's
# output so regenerated specs keep the same global security requirement.
LEGACY_SECURITY = [{'Default': []}]
# Canonical key ordering for each OpenAPI operation object; used when sorting
# generated operations so diffs against the legacy generator stay readable.
LEGACY_OPERATION_KEYS = [
    'responses',
    'tags',
    'summary',
    'description',
    'operationId',
    'parameters',
    'requestBody',
    'deprecated',
]

View file

@ -36,7 +36,7 @@ def _get_protocol_method(api: Api, method_name: str) -> Any | None:
if _protocol_methods_cache is None:
_protocol_methods_cache = {}
protocols = api_protocol_map()
from llama_stack.apis.tools import SpecialToolGroup, ToolRuntime
from llama_stack_api.tools import SpecialToolGroup, ToolRuntime
toolgroup_protocols = {
SpecialToolGroup.rag_tool: ToolRuntime,

View file

@ -12,16 +12,45 @@ import inspect
import re
import types
import typing
import uuid
from typing import Annotated, Any, get_args, get_origin
from fastapi import FastAPI
from pydantic import Field, create_model
from llama_stack.log import get_logger
from llama_stack_api import Api
from . import app as app_module
from .state import _dynamic_models, _extra_body_fields
from .state import _extra_body_fields, register_dynamic_model
logger = get_logger(name=__name__, category="core")
def _to_pascal_case(segment: str) -> str:
tokens = re.findall(r"[A-Za-z]+|\d+", segment)
return "".join(token.capitalize() for token in tokens if token)
def _compose_request_model_name(webmethod, http_method: str, variant: str | None = None) -> str:
segments = []
level = (webmethod.level or "").lower()
if level and level != "v1":
segments.append(_to_pascal_case(str(webmethod.level)))
for part in filter(None, webmethod.route.split("/")):
lower_part = part.lower()
if lower_part in {"v1", "v1alpha", "v1beta"}:
continue
if part.startswith("{"):
param = part[1:].split(":", 1)[0]
segments.append(f"By{_to_pascal_case(param)}")
else:
segments.append(_to_pascal_case(part))
if not segments:
segments.append("Root")
base_name = "".join(segments) + http_method.title() + "Request"
if variant:
base_name = f"{base_name}{variant}"
return base_name
def _extract_path_parameters(path: str) -> list[dict[str, Any]]:
@ -99,21 +128,21 @@ def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_
def _create_dynamic_request_model(
webmethod, query_parameters: list[tuple[str, type, Any]], use_any: bool = False, add_uuid: bool = False
api: Api,
webmethod,
http_method: str,
query_parameters: list[tuple[str, type, Any]],
use_any: bool = False,
variant_suffix: str | None = None,
) -> type | None:
"""Create a dynamic Pydantic model for request body."""
try:
field_definitions = _build_field_definitions(query_parameters, use_any)
if not field_definitions:
return None
clean_route = webmethod.route.replace("/", "_").replace("{", "").replace("}", "").replace("-", "_")
model_name = f"{clean_route}_Request"
if add_uuid:
model_name = f"{model_name}_{uuid.uuid4().hex[:8]}"
model_name = _compose_request_model_name(webmethod, http_method, variant_suffix)
request_model = create_model(model_name, **field_definitions)
_dynamic_models.append(request_model)
return request_model
return register_dynamic_model(model_name, request_model)
except Exception:
return None
@ -190,7 +219,7 @@ def _is_file_or_form_param(param_type: Any) -> bool:
def _is_extra_body_field(metadata_item: Any) -> bool:
"""Check if a metadata item is an ExtraBodyField instance."""
from llama_stack.schema_utils import ExtraBodyField
from llama_stack_api.schema_utils import ExtraBodyField
return isinstance(metadata_item, ExtraBodyField)
@ -232,7 +261,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
streaming_model = inner_type
else:
# Might be a registered schema - check if it's registered
from llama_stack.schema_utils import _registered_schemas
from llama_stack_api.schema_utils import _registered_schemas
if inner_type in _registered_schemas:
# We'll need to look this up later, but for now store the type
@ -247,7 +276,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
def _find_models_for_endpoint(
webmethod, api: Api, method_name: str, is_post_put: bool = False
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None]:
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None, str | None]:
"""
Find appropriate request and response models for an endpoint by analyzing the actual function signature.
This uses the protocol function to determine the correct models dynamically.
@ -259,16 +288,18 @@ def _find_models_for_endpoint(
is_post_put: Whether this is a POST, PUT, or PATCH request (GET requests should never have request bodies)
Returns:
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model)
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name)
where query_parameters is a list of (name, type, default_value) tuples
and file_form_params is a list of inspect.Parameter objects for File()/Form() params
and streaming_response_model is the model for streaming responses (AsyncIterator content)
"""
route_descriptor = f"{webmethod.method or 'UNKNOWN'} {webmethod.route}"
try:
# Get the function from the protocol
func = app_module._get_protocol_method(api, method_name)
if not func:
return None, None, [], [], None
logger.warning("No protocol method for %s.%s (%s)", api, method_name, route_descriptor)
return None, None, [], [], None, None
# Analyze the function signature
sig = inspect.signature(func)
@ -279,6 +310,7 @@ def _find_models_for_endpoint(
file_form_params = []
path_params = set()
extra_body_params = []
response_schema_name = None
# Extract path parameters from the route
if webmethod and hasattr(webmethod, "route"):
@ -391,23 +423,49 @@ def _find_models_for_endpoint(
elif origin is not None and (origin is types.UnionType or origin is typing.Union):
# Handle union types - extract both non-streaming and streaming models
response_model, streaming_response_model = _extract_response_models_from_union(return_annotation)
else:
try:
from fastapi import Response as FastAPIResponse
except ImportError:
FastAPIResponse = None
try:
from starlette.responses import Response as StarletteResponse
except ImportError:
StarletteResponse = None
return request_model, response_model, query_parameters, file_form_params, streaming_response_model
response_types = tuple(t for t in (FastAPIResponse, StarletteResponse) if t is not None)
if response_types and any(return_annotation is t for t in response_types):
response_schema_name = "Response"
except Exception:
# If we can't analyze the function signature, return None
return None, None, [], [], None
return request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name
except Exception as exc:
logger.warning(
"Failed to analyze endpoint %s.%s (%s): %s", api, method_name, route_descriptor, exc, exc_info=True
)
return None, None, [], [], None, None
def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
"""Create a FastAPI endpoint from a discovered route and webmethod."""
path = route.path
methods = route.methods
raw_methods = route.methods or set()
method_list = sorted({method.upper() for method in raw_methods if method and method.upper() != "HEAD"})
if not method_list:
method_list = ["GET"]
primary_method = method_list[0]
name = route.name
fastapi_path = path.replace("{", "{").replace("}", "}")
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
is_post_put = any(method in ["POST", "PUT", "PATCH"] for method in method_list)
request_model, response_model, query_parameters, file_form_params, streaming_response_model = (
(
request_model,
response_model,
query_parameters,
file_form_params,
streaming_response_model,
response_schema_name,
) = (
_find_models_for_endpoint(webmethod, api, name, is_post_put)
)
operation_description = _extract_operation_description_from_docstring(api, name)
@ -417,7 +475,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
func = app_module._get_protocol_method(api, name)
extra_body_params = getattr(func, "_extra_body_params", []) if func else []
if extra_body_params:
for method in methods:
for method in method_list:
key = (fastapi_path, method.upper())
_extra_body_fields[key] = extra_body_params
@ -447,12 +505,11 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description)
elif response_model and query_parameters:
if is_post_put:
# Try creating request model with type preservation, fallback to Any, then minimal
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=False)
request_model = _create_dynamic_request_model(api, webmethod, primary_method, query_parameters, use_any=False)
if not request_model:
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True)
if not request_model:
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True, add_uuid=True)
request_model = _create_dynamic_request_model(
api, webmethod, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
)
if request_model:
endpoint_func = _create_endpoint_with_request_model(
@ -532,16 +589,18 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
endpoint_func = no_params_endpoint
# Build response content with both application/json and text/event-stream if streaming
response_content = {}
response_content: dict[str, Any] = {}
if response_model:
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}}
elif response_schema_name:
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_schema_name}"}}
if streaming_response_model:
# Get the schema name for the streaming model
# It might be a registered schema or a Pydantic model
streaming_schema_name = None
# Check if it's a registered schema first (before checking __name__)
# because registered schemas might be Annotated types
from llama_stack.schema_utils import _registered_schemas
from llama_stack_api.schema_utils import _registered_schemas
if streaming_response_model in _registered_schemas:
streaming_schema_name = _registered_schemas[streaming_response_model]["name"]
@ -554,9 +613,6 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
}
# If no content types, use empty schema
if not response_content:
response_content["application/json"] = {"schema": {}}
# Add the endpoint to the FastAPI app
is_deprecated = webmethod.deprecated or False
route_kwargs = {
@ -564,16 +620,16 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
"tags": [_get_tag_from_api(api)],
"deprecated": is_deprecated,
"responses": {
200: {
"description": response_description,
"content": response_content,
},
400: {"$ref": "#/components/responses/BadRequest400"},
429: {"$ref": "#/components/responses/TooManyRequests429"},
500: {"$ref": "#/components/responses/InternalServerError500"},
"default": {"$ref": "#/components/responses/DefaultError"},
},
}
success_response: dict[str, Any] = {"description": response_description}
if response_content:
success_response["content"] = response_content
route_kwargs["responses"][200] = success_response
# FastAPI needs response_model parameter to properly generate OpenAPI spec
# Use the non-streaming response model if available
@ -581,6 +637,6 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
route_kwargs["response_model"] = response_model
method_map = {"GET": app.get, "POST": app.post, "PUT": app.put, "DELETE": app.delete, "PATCH": app.patch}
for method in methods:
if handler := method_map.get(method.upper()):
for method in method_list:
if handler := method_map.get(method):
handler(fastapi_path, **route_kwargs)(endpoint_func)

View file

@ -16,7 +16,7 @@ from typing import Any
import yaml
from fastapi.openapi.utils import get_openapi
from . import app, schema_collection, schema_filtering, schema_transforms
from . import app, schema_collection, schema_filtering, schema_transforms, state
def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
@ -29,6 +29,7 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
Returns:
The generated OpenAPI specification as a dictionary
"""
state.reset_generator_state()
# Create the FastAPI app
fastapi_app = app.create_llama_stack_app()
@ -143,7 +144,7 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
schema_transforms._fix_schema_issues(schema)
schema_transforms._apply_legacy_sorting(schema)
print("\n🔍 Validating generated schemas...")
print("\nValidating generated schemas...")
failed_schemas = [
name for schema, name in schemas_to_validate if not schema_transforms.validate_openapi_schema(schema, name)
]
@ -186,20 +187,20 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
# Write the modified YAML back
schema_transforms._write_yaml_file(yaml_path, yaml_data)
print(f"Generated YAML (stable): {yaml_path}")
print(f"Generated YAML (stable): {yaml_path}")
experimental_yaml_path = output_path / "experimental-llama-stack-spec.yaml"
schema_transforms._write_yaml_file(experimental_yaml_path, experimental_schema)
print(f"Generated YAML (experimental): {experimental_yaml_path}")
print(f"Generated YAML (experimental): {experimental_yaml_path}")
deprecated_yaml_path = output_path / "deprecated-llama-stack-spec.yaml"
schema_transforms._write_yaml_file(deprecated_yaml_path, deprecated_schema)
print(f"Generated YAML (deprecated): {deprecated_yaml_path}")
print(f"Generated YAML (deprecated): {deprecated_yaml_path}")
# Generate combined (stainless) spec
stainless_yaml_path = output_path / "stainless-llama-stack-spec.yaml"
schema_transforms._write_yaml_file(stainless_yaml_path, combined_schema)
print(f"Generated YAML (stainless/combined): {stainless_yaml_path}")
print(f"Generated YAML (stainless/combined): {stainless_yaml_path}")
return stable_schema
@ -213,25 +214,25 @@ def main():
args = parser.parse_args()
print("🚀 Generating OpenAPI specification using FastAPI...")
print(f"📁 Output directory: {args.output_dir}")
print("Generating OpenAPI specification using FastAPI...")
print(f"Output directory: {args.output_dir}")
try:
openapi_schema = generate_openapi_spec(output_dir=args.output_dir)
print("\nOpenAPI specification generated successfully!")
print(f"📊 Schemas: {len(openapi_schema.get('components', {}).get('schemas', {}))}")
print(f"🛣️ Paths: {len(openapi_schema.get('paths', {}))}")
print("\nOpenAPI specification generated successfully!")
print(f"Schemas: {len(openapi_schema.get('components', {}).get('schemas', {}))}")
print(f"Paths: {len(openapi_schema.get('paths', {}))}")
operation_count = sum(
1
for path_info in openapi_schema.get("paths", {}).values()
for method in ["get", "post", "put", "delete", "patch"]
if method in path_info
)
print(f"🔧 Operations: {operation_count}")
print(f"Operations: {operation_count}")
except Exception as e:
print(f"Error generating OpenAPI specification: {e}")
print(f"Error generating OpenAPI specification: {e}")
raise

View file

@ -22,6 +22,9 @@ from ._legacy_order import (
LEGACY_PATH_ORDER,
LEGACY_RESPONSE_ORDER,
LEGACY_SCHEMA_ORDER,
LEGACY_OPERATION_KEYS,
LEGACY_SECURITY,
LEGACY_TAGS,
LEGACY_TAG_GROUPS,
LEGACY_TAG_ORDER,
)
@ -121,7 +124,7 @@ def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
openapi_schema["components"]["responses"] = {}
try:
from llama_stack.apis.datatypes import Error
from llama_stack_api.datatypes import Error
schema_collection._ensure_components_schemas(openapi_schema)
if "Error" not in openapi_schema["components"]["schemas"]:
@ -129,6 +132,10 @@ def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
except ImportError:
pass
schema_collection._ensure_components_schemas(openapi_schema)
if "Response" not in openapi_schema["components"]["schemas"]:
openapi_schema["components"]["schemas"]["Response"] = {"title": "Response", "type": "object"}
# Define standard HTTP error responses
error_responses = {
400: {
@ -848,6 +855,20 @@ def _apply_legacy_sorting(openapi_schema: dict[str, Any]) -> dict[str, Any]:
paths = openapi_schema.get("paths")
if isinstance(paths, dict):
openapi_schema["paths"] = order_mapping(paths, LEGACY_PATH_ORDER)
for path, path_item in openapi_schema["paths"].items():
if not isinstance(path_item, dict):
continue
ordered_path_item = OrderedDict()
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
if method in path_item:
ordered_path_item[method] = order_mapping(path_item[method], LEGACY_OPERATION_KEYS)
for key, value in path_item.items():
if key not in ordered_path_item:
if isinstance(value, dict) and key.lower() in {"get", "post", "put", "delete", "patch", "head", "options"}:
ordered_path_item[key] = order_mapping(value, LEGACY_OPERATION_KEYS)
else:
ordered_path_item[key] = value
openapi_schema["paths"][path] = ordered_path_item
components = openapi_schema.setdefault("components", {})
schemas = components.get("schemas")
@ -857,30 +878,14 @@ def _apply_legacy_sorting(openapi_schema: dict[str, Any]) -> dict[str, Any]:
if isinstance(responses, dict):
components["responses"] = order_mapping(responses, LEGACY_RESPONSE_ORDER)
tags = openapi_schema.get("tags")
if isinstance(tags, list):
tag_priority = {name: idx for idx, name in enumerate(LEGACY_TAG_ORDER)}
if LEGACY_TAGS:
openapi_schema["tags"] = LEGACY_TAGS
def tag_sort(tag_obj: dict[str, Any]) -> tuple[int, int | str]:
name = tag_obj.get("name", "")
if name in tag_priority:
return (0, tag_priority[name])
return (1, name)
if LEGACY_TAG_GROUPS:
openapi_schema["x-tagGroups"] = LEGACY_TAG_GROUPS
openapi_schema["tags"] = sorted(tags, key=tag_sort)
tag_groups = openapi_schema.get("x-tagGroups")
if isinstance(tag_groups, list) and LEGACY_TAG_GROUPS:
legacy_tags = LEGACY_TAG_GROUPS[0].get("tags", [])
tag_priority = {name: idx for idx, name in enumerate(legacy_tags)}
for group in tag_groups:
group_tags = group.get("tags")
if isinstance(group_tags, list):
group["tags"] = sorted(
group_tags,
key=lambda name: (0, tag_priority[name]) if name in tag_priority else (1, name),
)
openapi_schema["x-tagGroups"] = tag_groups
if LEGACY_SECURITY:
openapi_schema["security"] = LEGACY_SECURITY
return openapi_schema
@ -914,12 +919,11 @@ def validate_openapi_schema(schema: dict[str, Any], schema_name: str = "OpenAPI
"""
try:
validate_spec(schema)
print(f"{schema_name} is valid")
print(f"{schema_name} is valid")
return True
except OpenAPISpecValidatorError as e:
print(f"{schema_name} validation failed:")
print(f" {e}")
print(f"{schema_name} validation failed: {e}")
return False
except Exception as e:
print(f"{schema_name} validation error: {e}")
print(f"{schema_name} validation error: {e}")
return False

View file

@ -14,6 +14,7 @@ from llama_stack_api import Api
# Dynamic request models created during endpoint generation, in creation order.
_dynamic_models: list[Any] = []
# Deduplication index for dynamic models, keyed by generated model name.
_dynamic_model_registry: dict[str, type] = {}

# Cache for protocol methods to avoid repeated lookups
_protocol_methods_cache: dict[Api, dict[str, Any]] | None = None

# Global dict to store extra body field information by endpoint
# Key: (path, method) tuple, Value: list of (param_name, param_type, description) tuples
_extra_body_fields: dict[tuple[str, str], list[tuple[str, type, str | None]]] = {}


def register_dynamic_model(name: str, model: type) -> type:
    """Register *model* under *name*, returning the first model seen for that name.

    Registering the same name twice is a no-op that hands back the original
    model, so repeated endpoint generation cannot create duplicate schemas.
    """
    already_registered = _dynamic_model_registry.get(name)
    if already_registered is not None:
        return already_registered
    _dynamic_model_registry[name] = model
    _dynamic_models.append(model)
    return model


def reset_generator_state() -> None:
    """Clear per-run caches so repeated generations stay deterministic."""
    for cache in (_dynamic_models, _dynamic_model_registry, _extra_body_fields):
        cache.clear()