schema naming cleanup, should be much closer

This commit is contained in:
Ashwin Bharambe 2025-11-14 13:07:34 -08:00
parent 69e1176ff8
commit 5293b4e5e9
11 changed files with 20459 additions and 12557 deletions

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -12,7 +12,6 @@ These lists help the new generator match the previous ordering so that diffs
remain readable while we debug schema content regressions. Remove once stable. remain readable while we debug schema content regressions. Remove once stable.
""" """
# TODO: remove once generator output stabilizes
LEGACY_PATH_ORDER = ['/v1/batches', LEGACY_PATH_ORDER = ['/v1/batches',
'/v1/batches/{batch_id}', '/v1/batches/{batch_id}',
'/v1/batches/{batch_id}/cancel', '/v1/batches/{batch_id}/cancel',
@ -364,6 +363,57 @@ LEGACY_SCHEMA_ORDER = ['Error',
LEGACY_RESPONSE_ORDER = ['BadRequest400', 'TooManyRequests429', 'InternalServerError500', 'DefaultError'] LEGACY_RESPONSE_ORDER = ['BadRequest400', 'TooManyRequests429', 'InternalServerError500', 'DefaultError']
# Tag metadata emitted verbatim into the generated OpenAPI spec so output stays
# byte-compatible with the legacy generator while content regressions are debugged.
# NOTE(review): in the 'Batches' entry the long explanatory sentence sits in
# 'x-displayName' while the short notes sit in 'description' — these look
# swapped relative to every other entry; confirm against the legacy spec.
LEGACY_TAGS = [{'description': 'APIs for creating and interacting with agentic systems.',
'name': 'Agents',
'x-displayName': 'Agents'},
{'description': 'The API is designed to allow use of openai client libraries for seamless integration.\n'
'\n'
'This API provides the following extensions:\n'
' - idempotent batch creation\n'
'\n'
'Note: This API is currently under active development and may undergo changes.',
'name': 'Batches',
'x-displayName': 'The Batches API enables efficient processing of multiple requests in a single operation, '
'particularly useful for processing large datasets, batch evaluation workflows, and cost-effective '
'inference at scale.'},
{'description': '', 'name': 'Benchmarks'},
{'description': 'Protocol for conversation management operations.',
'name': 'Conversations',
'x-displayName': 'Conversations'},
{'description': '', 'name': 'DatasetIO'},
{'description': '', 'name': 'Datasets'},
{'description': 'Llama Stack Evaluation API for running evaluations on model and agent candidates.',
'name': 'Eval',
'x-displayName': 'Evaluations'},
{'description': 'This API is used to upload documents that can be used with other Llama Stack APIs.',
'name': 'Files',
'x-displayName': 'Files'},
{'description': 'Llama Stack Inference API for generating completions, chat completions, and embeddings.\n'
'\n'
'This API provides the raw interface to the underlying models. Three kinds of models are supported:\n'
'- LLM models: these models generate "raw" and "chat" (conversational) completions.\n'
'- Embedding models: these models generate embeddings to be used for semantic search.\n'
'- Rerank models: these models reorder the documents based on their relevance to a query.',
'name': 'Inference',
'x-displayName': 'Inference'},
{'description': 'APIs for inspecting the Llama Stack service, including health status, available API routes with '
'methods and implementing providers.',
'name': 'Inspect',
'x-displayName': 'Inspect'},
{'description': '', 'name': 'Models'},
{'description': '', 'name': 'PostTraining (Coming Soon)'},
{'description': 'Protocol for prompt management operations.', 'name': 'Prompts', 'x-displayName': 'Prompts'},
{'description': 'Providers API for inspecting, listing, and modifying providers and their configurations.',
'name': 'Providers',
'x-displayName': 'Providers'},
{'description': 'OpenAI-compatible Moderations API.', 'name': 'Safety', 'x-displayName': 'Safety'},
{'description': '', 'name': 'Scoring'},
{'description': '', 'name': 'ScoringFunctions'},
{'description': '', 'name': 'Shields'},
{'description': '', 'name': 'ToolGroups'},
{'description': '', 'name': 'ToolRuntime'},
{'description': '', 'name': 'VectorIO'}]
LEGACY_TAG_ORDER = ['Agents', LEGACY_TAG_ORDER = ['Agents',
'Batches', 'Batches',
'Benchmarks', 'Benchmarks',
@ -408,3 +458,16 @@ LEGACY_TAG_GROUPS = [{'name': 'Operations',
'ToolGroups', 'ToolGroups',
'ToolRuntime', 'ToolRuntime',
'VectorIO']}] 'VectorIO']}]
# Top-level `security` requirement copied verbatim into the generated spec
# (legacy parity; the 'Default' scheme is defined elsewhere in the spec).
LEGACY_SECURITY = [{'Default': []}]
# Canonical key order for each operation object within a path item, so
# regenerated specs diff cleanly against the legacy generator's output.
LEGACY_OPERATION_KEYS = [
    'responses',
    'tags',
    'summary',
    'description',
    'operationId',
    'parameters',
    'requestBody',
    'deprecated',
]

View file

@ -36,7 +36,7 @@ def _get_protocol_method(api: Api, method_name: str) -> Any | None:
if _protocol_methods_cache is None: if _protocol_methods_cache is None:
_protocol_methods_cache = {} _protocol_methods_cache = {}
protocols = api_protocol_map() protocols = api_protocol_map()
from llama_stack.apis.tools import SpecialToolGroup, ToolRuntime from llama_stack_api.tools import SpecialToolGroup, ToolRuntime
toolgroup_protocols = { toolgroup_protocols = {
SpecialToolGroup.rag_tool: ToolRuntime, SpecialToolGroup.rag_tool: ToolRuntime,

View file

@ -12,16 +12,45 @@ import inspect
import re import re
import types import types
import typing import typing
import uuid
from typing import Annotated, Any, get_args, get_origin from typing import Annotated, Any, get_args, get_origin
from fastapi import FastAPI from fastapi import FastAPI
from pydantic import Field, create_model from pydantic import Field, create_model
from llama_stack.log import get_logger
from llama_stack_api import Api from llama_stack_api import Api
from . import app as app_module from . import app as app_module
from .state import _dynamic_models, _extra_body_fields from .state import _extra_body_fields, register_dynamic_model
logger = get_logger(name=__name__, category="core")
def _to_pascal_case(segment: str) -> str:
tokens = re.findall(r"[A-Za-z]+|\d+", segment)
return "".join(token.capitalize() for token in tokens if token)
def _compose_request_model_name(webmethod, http_method: str, variant: str | None = None) -> str:
    """Build a deterministic PascalCase request-model name for an endpoint.

    The name is assembled from the webmethod's version level (omitted for
    ``v1``), its route segments (path parameters become ``By<Param>``), the
    HTTP method, and a trailing ``Request`` suffix, plus an optional variant
    suffix (e.g. ``Loose``). Version segments embedded in the route itself
    (``v1``, ``v1alpha``, ``v1beta``) are skipped since the level already
    covers them.
    """
    parts: list[str] = []
    version = (webmethod.level or "").lower()
    if version and version != "v1":
        parts.append(_to_pascal_case(str(webmethod.level)))
    for segment in webmethod.route.split("/"):
        if not segment or segment.lower() in {"v1", "v1alpha", "v1beta"}:
            continue
        if segment.startswith("{"):
            # "{batch_id}" -> "ByBatchId"; any ":converter" suffix and the
            # trailing brace are discarded by the split / regex tokenizer.
            param_name = segment[1:].split(":", 1)[0]
            parts.append("By" + _to_pascal_case(param_name))
        else:
            parts.append(_to_pascal_case(segment))
    stem = "".join(parts) or "Root"
    name = f"{stem}{http_method.title()}Request"
    return f"{name}{variant}" if variant else name
def _extract_path_parameters(path: str) -> list[dict[str, Any]]: def _extract_path_parameters(path: str) -> list[dict[str, Any]]:
@ -99,21 +128,21 @@ def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_
def _create_dynamic_request_model( def _create_dynamic_request_model(
webmethod, query_parameters: list[tuple[str, type, Any]], use_any: bool = False, add_uuid: bool = False api: Api,
webmethod,
http_method: str,
query_parameters: list[tuple[str, type, Any]],
use_any: bool = False,
variant_suffix: str | None = None,
) -> type | None: ) -> type | None:
"""Create a dynamic Pydantic model for request body.""" """Create a dynamic Pydantic model for request body."""
try: try:
field_definitions = _build_field_definitions(query_parameters, use_any) field_definitions = _build_field_definitions(query_parameters, use_any)
if not field_definitions: if not field_definitions:
return None return None
clean_route = webmethod.route.replace("/", "_").replace("{", "").replace("}", "").replace("-", "_") model_name = _compose_request_model_name(webmethod, http_method, variant_suffix)
model_name = f"{clean_route}_Request"
if add_uuid:
model_name = f"{model_name}_{uuid.uuid4().hex[:8]}"
request_model = create_model(model_name, **field_definitions) request_model = create_model(model_name, **field_definitions)
_dynamic_models.append(request_model) return register_dynamic_model(model_name, request_model)
return request_model
except Exception: except Exception:
return None return None
@ -190,7 +219,7 @@ def _is_file_or_form_param(param_type: Any) -> bool:
def _is_extra_body_field(metadata_item: Any) -> bool: def _is_extra_body_field(metadata_item: Any) -> bool:
"""Check if a metadata item is an ExtraBodyField instance.""" """Check if a metadata item is an ExtraBodyField instance."""
from llama_stack.schema_utils import ExtraBodyField from llama_stack_api.schema_utils import ExtraBodyField
return isinstance(metadata_item, ExtraBodyField) return isinstance(metadata_item, ExtraBodyField)
@ -232,7 +261,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
streaming_model = inner_type streaming_model = inner_type
else: else:
# Might be a registered schema - check if it's registered # Might be a registered schema - check if it's registered
from llama_stack.schema_utils import _registered_schemas from llama_stack_api.schema_utils import _registered_schemas
if inner_type in _registered_schemas: if inner_type in _registered_schemas:
# We'll need to look this up later, but for now store the type # We'll need to look this up later, but for now store the type
@ -247,7 +276,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
def _find_models_for_endpoint( def _find_models_for_endpoint(
webmethod, api: Api, method_name: str, is_post_put: bool = False webmethod, api: Api, method_name: str, is_post_put: bool = False
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None]: ) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None, str | None]:
""" """
Find appropriate request and response models for an endpoint by analyzing the actual function signature. Find appropriate request and response models for an endpoint by analyzing the actual function signature.
This uses the protocol function to determine the correct models dynamically. This uses the protocol function to determine the correct models dynamically.
@ -259,16 +288,18 @@ def _find_models_for_endpoint(
is_post_put: Whether this is a POST, PUT, or PATCH request (GET requests should never have request bodies) is_post_put: Whether this is a POST, PUT, or PATCH request (GET requests should never have request bodies)
Returns: Returns:
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model) tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name)
where query_parameters is a list of (name, type, default_value) tuples where query_parameters is a list of (name, type, default_value) tuples
and file_form_params is a list of inspect.Parameter objects for File()/Form() params and file_form_params is a list of inspect.Parameter objects for File()/Form() params
and streaming_response_model is the model for streaming responses (AsyncIterator content) and streaming_response_model is the model for streaming responses (AsyncIterator content)
""" """
route_descriptor = f"{webmethod.method or 'UNKNOWN'} {webmethod.route}"
try: try:
# Get the function from the protocol # Get the function from the protocol
func = app_module._get_protocol_method(api, method_name) func = app_module._get_protocol_method(api, method_name)
if not func: if not func:
return None, None, [], [], None logger.warning("No protocol method for %s.%s (%s)", api, method_name, route_descriptor)
return None, None, [], [], None, None
# Analyze the function signature # Analyze the function signature
sig = inspect.signature(func) sig = inspect.signature(func)
@ -279,6 +310,7 @@ def _find_models_for_endpoint(
file_form_params = [] file_form_params = []
path_params = set() path_params = set()
extra_body_params = [] extra_body_params = []
response_schema_name = None
# Extract path parameters from the route # Extract path parameters from the route
if webmethod and hasattr(webmethod, "route"): if webmethod and hasattr(webmethod, "route"):
@ -391,23 +423,49 @@ def _find_models_for_endpoint(
elif origin is not None and (origin is types.UnionType or origin is typing.Union): elif origin is not None and (origin is types.UnionType or origin is typing.Union):
# Handle union types - extract both non-streaming and streaming models # Handle union types - extract both non-streaming and streaming models
response_model, streaming_response_model = _extract_response_models_from_union(return_annotation) response_model, streaming_response_model = _extract_response_models_from_union(return_annotation)
else:
try:
from fastapi import Response as FastAPIResponse
except ImportError:
FastAPIResponse = None
try:
from starlette.responses import Response as StarletteResponse
except ImportError:
StarletteResponse = None
return request_model, response_model, query_parameters, file_form_params, streaming_response_model response_types = tuple(t for t in (FastAPIResponse, StarletteResponse) if t is not None)
if response_types and any(return_annotation is t for t in response_types):
response_schema_name = "Response"
except Exception: return request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name
# If we can't analyze the function signature, return None
return None, None, [], [], None except Exception as exc:
logger.warning(
"Failed to analyze endpoint %s.%s (%s): %s", api, method_name, route_descriptor, exc, exc_info=True
)
return None, None, [], [], None, None
def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api): def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
"""Create a FastAPI endpoint from a discovered route and webmethod.""" """Create a FastAPI endpoint from a discovered route and webmethod."""
path = route.path path = route.path
methods = route.methods raw_methods = route.methods or set()
method_list = sorted({method.upper() for method in raw_methods if method and method.upper() != "HEAD"})
if not method_list:
method_list = ["GET"]
primary_method = method_list[0]
name = route.name name = route.name
fastapi_path = path.replace("{", "{").replace("}", "}") fastapi_path = path.replace("{", "{").replace("}", "}")
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods) is_post_put = any(method in ["POST", "PUT", "PATCH"] for method in method_list)
request_model, response_model, query_parameters, file_form_params, streaming_response_model = ( (
request_model,
response_model,
query_parameters,
file_form_params,
streaming_response_model,
response_schema_name,
) = (
_find_models_for_endpoint(webmethod, api, name, is_post_put) _find_models_for_endpoint(webmethod, api, name, is_post_put)
) )
operation_description = _extract_operation_description_from_docstring(api, name) operation_description = _extract_operation_description_from_docstring(api, name)
@ -417,7 +475,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
func = app_module._get_protocol_method(api, name) func = app_module._get_protocol_method(api, name)
extra_body_params = getattr(func, "_extra_body_params", []) if func else [] extra_body_params = getattr(func, "_extra_body_params", []) if func else []
if extra_body_params: if extra_body_params:
for method in methods: for method in method_list:
key = (fastapi_path, method.upper()) key = (fastapi_path, method.upper())
_extra_body_fields[key] = extra_body_params _extra_body_fields[key] = extra_body_params
@ -447,12 +505,11 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description) endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description)
elif response_model and query_parameters: elif response_model and query_parameters:
if is_post_put: if is_post_put:
# Try creating request model with type preservation, fallback to Any, then minimal request_model = _create_dynamic_request_model(api, webmethod, primary_method, query_parameters, use_any=False)
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=False)
if not request_model: if not request_model:
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True) request_model = _create_dynamic_request_model(
if not request_model: api, webmethod, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True, add_uuid=True) )
if request_model: if request_model:
endpoint_func = _create_endpoint_with_request_model( endpoint_func = _create_endpoint_with_request_model(
@ -532,16 +589,18 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
endpoint_func = no_params_endpoint endpoint_func = no_params_endpoint
# Build response content with both application/json and text/event-stream if streaming # Build response content with both application/json and text/event-stream if streaming
response_content = {} response_content: dict[str, Any] = {}
if response_model: if response_model:
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}} response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}}
elif response_schema_name:
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_schema_name}"}}
if streaming_response_model: if streaming_response_model:
# Get the schema name for the streaming model # Get the schema name for the streaming model
# It might be a registered schema or a Pydantic model # It might be a registered schema or a Pydantic model
streaming_schema_name = None streaming_schema_name = None
# Check if it's a registered schema first (before checking __name__) # Check if it's a registered schema first (before checking __name__)
# because registered schemas might be Annotated types # because registered schemas might be Annotated types
from llama_stack.schema_utils import _registered_schemas from llama_stack_api.schema_utils import _registered_schemas
if streaming_response_model in _registered_schemas: if streaming_response_model in _registered_schemas:
streaming_schema_name = _registered_schemas[streaming_response_model]["name"] streaming_schema_name = _registered_schemas[streaming_response_model]["name"]
@ -554,9 +613,6 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
} }
# If no content types, use empty schema # If no content types, use empty schema
if not response_content:
response_content["application/json"] = {"schema": {}}
# Add the endpoint to the FastAPI app # Add the endpoint to the FastAPI app
is_deprecated = webmethod.deprecated or False is_deprecated = webmethod.deprecated or False
route_kwargs = { route_kwargs = {
@ -564,16 +620,16 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
"tags": [_get_tag_from_api(api)], "tags": [_get_tag_from_api(api)],
"deprecated": is_deprecated, "deprecated": is_deprecated,
"responses": { "responses": {
200: {
"description": response_description,
"content": response_content,
},
400: {"$ref": "#/components/responses/BadRequest400"}, 400: {"$ref": "#/components/responses/BadRequest400"},
429: {"$ref": "#/components/responses/TooManyRequests429"}, 429: {"$ref": "#/components/responses/TooManyRequests429"},
500: {"$ref": "#/components/responses/InternalServerError500"}, 500: {"$ref": "#/components/responses/InternalServerError500"},
"default": {"$ref": "#/components/responses/DefaultError"}, "default": {"$ref": "#/components/responses/DefaultError"},
}, },
} }
success_response: dict[str, Any] = {"description": response_description}
if response_content:
success_response["content"] = response_content
route_kwargs["responses"][200] = success_response
# FastAPI needs response_model parameter to properly generate OpenAPI spec # FastAPI needs response_model parameter to properly generate OpenAPI spec
# Use the non-streaming response model if available # Use the non-streaming response model if available
@ -581,6 +637,6 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
route_kwargs["response_model"] = response_model route_kwargs["response_model"] = response_model
method_map = {"GET": app.get, "POST": app.post, "PUT": app.put, "DELETE": app.delete, "PATCH": app.patch} method_map = {"GET": app.get, "POST": app.post, "PUT": app.put, "DELETE": app.delete, "PATCH": app.patch}
for method in methods: for method in method_list:
if handler := method_map.get(method.upper()): if handler := method_map.get(method):
handler(fastapi_path, **route_kwargs)(endpoint_func) handler(fastapi_path, **route_kwargs)(endpoint_func)

View file

@ -16,7 +16,7 @@ from typing import Any
import yaml import yaml
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from . import app, schema_collection, schema_filtering, schema_transforms from . import app, schema_collection, schema_filtering, schema_transforms, state
def generate_openapi_spec(output_dir: str) -> dict[str, Any]: def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
@ -29,6 +29,7 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
Returns: Returns:
The generated OpenAPI specification as a dictionary The generated OpenAPI specification as a dictionary
""" """
state.reset_generator_state()
# Create the FastAPI app # Create the FastAPI app
fastapi_app = app.create_llama_stack_app() fastapi_app = app.create_llama_stack_app()
@ -143,7 +144,7 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
schema_transforms._fix_schema_issues(schema) schema_transforms._fix_schema_issues(schema)
schema_transforms._apply_legacy_sorting(schema) schema_transforms._apply_legacy_sorting(schema)
print("\n🔍 Validating generated schemas...") print("\nValidating generated schemas...")
failed_schemas = [ failed_schemas = [
name for schema, name in schemas_to_validate if not schema_transforms.validate_openapi_schema(schema, name) name for schema, name in schemas_to_validate if not schema_transforms.validate_openapi_schema(schema, name)
] ]
@ -186,20 +187,20 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
# Write the modified YAML back # Write the modified YAML back
schema_transforms._write_yaml_file(yaml_path, yaml_data) schema_transforms._write_yaml_file(yaml_path, yaml_data)
print(f"Generated YAML (stable): {yaml_path}") print(f"Generated YAML (stable): {yaml_path}")
experimental_yaml_path = output_path / "experimental-llama-stack-spec.yaml" experimental_yaml_path = output_path / "experimental-llama-stack-spec.yaml"
schema_transforms._write_yaml_file(experimental_yaml_path, experimental_schema) schema_transforms._write_yaml_file(experimental_yaml_path, experimental_schema)
print(f"Generated YAML (experimental): {experimental_yaml_path}") print(f"Generated YAML (experimental): {experimental_yaml_path}")
deprecated_yaml_path = output_path / "deprecated-llama-stack-spec.yaml" deprecated_yaml_path = output_path / "deprecated-llama-stack-spec.yaml"
schema_transforms._write_yaml_file(deprecated_yaml_path, deprecated_schema) schema_transforms._write_yaml_file(deprecated_yaml_path, deprecated_schema)
print(f"Generated YAML (deprecated): {deprecated_yaml_path}") print(f"Generated YAML (deprecated): {deprecated_yaml_path}")
# Generate combined (stainless) spec # Generate combined (stainless) spec
stainless_yaml_path = output_path / "stainless-llama-stack-spec.yaml" stainless_yaml_path = output_path / "stainless-llama-stack-spec.yaml"
schema_transforms._write_yaml_file(stainless_yaml_path, combined_schema) schema_transforms._write_yaml_file(stainless_yaml_path, combined_schema)
print(f"Generated YAML (stainless/combined): {stainless_yaml_path}") print(f"Generated YAML (stainless/combined): {stainless_yaml_path}")
return stable_schema return stable_schema
@ -213,25 +214,25 @@ def main():
args = parser.parse_args() args = parser.parse_args()
print("🚀 Generating OpenAPI specification using FastAPI...") print("Generating OpenAPI specification using FastAPI...")
print(f"📁 Output directory: {args.output_dir}") print(f"Output directory: {args.output_dir}")
try: try:
openapi_schema = generate_openapi_spec(output_dir=args.output_dir) openapi_schema = generate_openapi_spec(output_dir=args.output_dir)
print("\nOpenAPI specification generated successfully!") print("\nOpenAPI specification generated successfully!")
print(f"📊 Schemas: {len(openapi_schema.get('components', {}).get('schemas', {}))}") print(f"Schemas: {len(openapi_schema.get('components', {}).get('schemas', {}))}")
print(f"🛣️ Paths: {len(openapi_schema.get('paths', {}))}") print(f"Paths: {len(openapi_schema.get('paths', {}))}")
operation_count = sum( operation_count = sum(
1 1
for path_info in openapi_schema.get("paths", {}).values() for path_info in openapi_schema.get("paths", {}).values()
for method in ["get", "post", "put", "delete", "patch"] for method in ["get", "post", "put", "delete", "patch"]
if method in path_info if method in path_info
) )
print(f"🔧 Operations: {operation_count}") print(f"Operations: {operation_count}")
except Exception as e: except Exception as e:
print(f"Error generating OpenAPI specification: {e}") print(f"Error generating OpenAPI specification: {e}")
raise raise

View file

@ -22,6 +22,9 @@ from ._legacy_order import (
LEGACY_PATH_ORDER, LEGACY_PATH_ORDER,
LEGACY_RESPONSE_ORDER, LEGACY_RESPONSE_ORDER,
LEGACY_SCHEMA_ORDER, LEGACY_SCHEMA_ORDER,
LEGACY_OPERATION_KEYS,
LEGACY_SECURITY,
LEGACY_TAGS,
LEGACY_TAG_GROUPS, LEGACY_TAG_GROUPS,
LEGACY_TAG_ORDER, LEGACY_TAG_ORDER,
) )
@ -121,7 +124,7 @@ def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
openapi_schema["components"]["responses"] = {} openapi_schema["components"]["responses"] = {}
try: try:
from llama_stack.apis.datatypes import Error from llama_stack_api.datatypes import Error
schema_collection._ensure_components_schemas(openapi_schema) schema_collection._ensure_components_schemas(openapi_schema)
if "Error" not in openapi_schema["components"]["schemas"]: if "Error" not in openapi_schema["components"]["schemas"]:
@ -129,6 +132,10 @@ def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
except ImportError: except ImportError:
pass pass
schema_collection._ensure_components_schemas(openapi_schema)
if "Response" not in openapi_schema["components"]["schemas"]:
openapi_schema["components"]["schemas"]["Response"] = {"title": "Response", "type": "object"}
# Define standard HTTP error responses # Define standard HTTP error responses
error_responses = { error_responses = {
400: { 400: {
@ -848,6 +855,20 @@ def _apply_legacy_sorting(openapi_schema: dict[str, Any]) -> dict[str, Any]:
paths = openapi_schema.get("paths") paths = openapi_schema.get("paths")
if isinstance(paths, dict): if isinstance(paths, dict):
openapi_schema["paths"] = order_mapping(paths, LEGACY_PATH_ORDER) openapi_schema["paths"] = order_mapping(paths, LEGACY_PATH_ORDER)
for path, path_item in openapi_schema["paths"].items():
if not isinstance(path_item, dict):
continue
ordered_path_item = OrderedDict()
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
if method in path_item:
ordered_path_item[method] = order_mapping(path_item[method], LEGACY_OPERATION_KEYS)
for key, value in path_item.items():
if key not in ordered_path_item:
if isinstance(value, dict) and key.lower() in {"get", "post", "put", "delete", "patch", "head", "options"}:
ordered_path_item[key] = order_mapping(value, LEGACY_OPERATION_KEYS)
else:
ordered_path_item[key] = value
openapi_schema["paths"][path] = ordered_path_item
components = openapi_schema.setdefault("components", {}) components = openapi_schema.setdefault("components", {})
schemas = components.get("schemas") schemas = components.get("schemas")
@ -857,30 +878,14 @@ def _apply_legacy_sorting(openapi_schema: dict[str, Any]) -> dict[str, Any]:
if isinstance(responses, dict): if isinstance(responses, dict):
components["responses"] = order_mapping(responses, LEGACY_RESPONSE_ORDER) components["responses"] = order_mapping(responses, LEGACY_RESPONSE_ORDER)
tags = openapi_schema.get("tags") if LEGACY_TAGS:
if isinstance(tags, list): openapi_schema["tags"] = LEGACY_TAGS
tag_priority = {name: idx for idx, name in enumerate(LEGACY_TAG_ORDER)}
def tag_sort(tag_obj: dict[str, Any]) -> tuple[int, int | str]: if LEGACY_TAG_GROUPS:
name = tag_obj.get("name", "") openapi_schema["x-tagGroups"] = LEGACY_TAG_GROUPS
if name in tag_priority:
return (0, tag_priority[name])
return (1, name)
openapi_schema["tags"] = sorted(tags, key=tag_sort) if LEGACY_SECURITY:
openapi_schema["security"] = LEGACY_SECURITY
tag_groups = openapi_schema.get("x-tagGroups")
if isinstance(tag_groups, list) and LEGACY_TAG_GROUPS:
legacy_tags = LEGACY_TAG_GROUPS[0].get("tags", [])
tag_priority = {name: idx for idx, name in enumerate(legacy_tags)}
for group in tag_groups:
group_tags = group.get("tags")
if isinstance(group_tags, list):
group["tags"] = sorted(
group_tags,
key=lambda name: (0, tag_priority[name]) if name in tag_priority else (1, name),
)
openapi_schema["x-tagGroups"] = tag_groups
return openapi_schema return openapi_schema
@ -914,12 +919,11 @@ def validate_openapi_schema(schema: dict[str, Any], schema_name: str = "OpenAPI
""" """
try: try:
validate_spec(schema) validate_spec(schema)
print(f"{schema_name} is valid") print(f"{schema_name} is valid")
return True return True
except OpenAPISpecValidatorError as e: except OpenAPISpecValidatorError as e:
print(f"{schema_name} validation failed:") print(f"{schema_name} validation failed: {e}")
print(f" {e}")
return False return False
except Exception as e: except Exception as e:
print(f"{schema_name} validation error: {e}") print(f"{schema_name} validation error: {e}")
return False return False

View file

@ -14,6 +14,7 @@ from llama_stack_api import Api
# Global list to store dynamic models created during endpoint generation # Global list to store dynamic models created during endpoint generation
_dynamic_models: list[Any] = [] _dynamic_models: list[Any] = []
_dynamic_model_registry: dict[str, type] = {}
# Cache for protocol methods to avoid repeated lookups # Cache for protocol methods to avoid repeated lookups
_protocol_methods_cache: dict[Api, dict[str, Any]] | None = None _protocol_methods_cache: dict[Api, dict[str, Any]] | None = None
@ -21,3 +22,20 @@ _protocol_methods_cache: dict[Api, dict[str, Any]] | None = None
# Global dict to store extra body field information by endpoint # Global dict to store extra body field information by endpoint
# Key: (path, method) tuple, Value: list of (param_name, param_type, description) tuples # Key: (path, method) tuple, Value: list of (param_name, param_type, description) tuples
_extra_body_fields: dict[tuple[str, str], list[tuple[str, type, str | None]]] = {} _extra_body_fields: dict[tuple[str, str], list[tuple[str, type, str | None]]] = {}
def register_dynamic_model(name: str, model: type) -> type:
    """Register a dynamically generated request model, deduplicating by name.

    When a model with the same name was registered earlier in this run, that
    existing class is returned so every caller shares a single class per
    schema name; otherwise the new model is recorded and returned.
    """
    if name in _dynamic_model_registry:
        return _dynamic_model_registry[name]
    _dynamic_model_registry[name] = model
    _dynamic_models.append(model)
    return model
def reset_generator_state() -> None:
    """Reset module-level generator caches before a new run.

    Clears the dynamic-model list/registry and the extra-body field map so
    repeated spec generations in one process produce identical output.
    """
    for cache in (_dynamic_models, _dynamic_model_registry, _extra_body_fields):
        cache.clear()