mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
schema naming cleanup, should be much closer
This commit is contained in:
parent
69e1176ff8
commit
5293b4e5e9
11 changed files with 20459 additions and 12557 deletions
File diff suppressed because it is too large
Load diff
5466
docs/static/deprecated-llama-stack-spec.yaml
vendored
5466
docs/static/deprecated-llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
4418
docs/static/experimental-llama-stack-spec.yaml
vendored
4418
docs/static/experimental-llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
6874
docs/static/llama-stack-spec.yaml
vendored
6874
docs/static/llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
7977
docs/static/stainless-llama-stack-spec.yaml
vendored
7977
docs/static/stainless-llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
|
|
@ -12,7 +12,6 @@ These lists help the new generator match the previous ordering so that diffs
|
|||
remain readable while we debug schema content regressions. Remove once stable.
|
||||
"""
|
||||
|
||||
# TODO: remove once generator output stabilizes
|
||||
LEGACY_PATH_ORDER = ['/v1/batches',
|
||||
'/v1/batches/{batch_id}',
|
||||
'/v1/batches/{batch_id}/cancel',
|
||||
|
|
@ -364,6 +363,57 @@ LEGACY_SCHEMA_ORDER = ['Error',
|
|||
|
||||
LEGACY_RESPONSE_ORDER = ['BadRequest400', 'TooManyRequests429', 'InternalServerError500', 'DefaultError']
|
||||
|
||||
LEGACY_TAGS = [{'description': 'APIs for creating and interacting with agentic systems.',
|
||||
'name': 'Agents',
|
||||
'x-displayName': 'Agents'},
|
||||
{'description': 'The API is designed to allow use of openai client libraries for seamless integration.\n'
|
||||
'\n'
|
||||
'This API provides the following extensions:\n'
|
||||
' - idempotent batch creation\n'
|
||||
'\n'
|
||||
'Note: This API is currently under active development and may undergo changes.',
|
||||
'name': 'Batches',
|
||||
'x-displayName': 'The Batches API enables efficient processing of multiple requests in a single operation, '
|
||||
'particularly useful for processing large datasets, batch evaluation workflows, and cost-effective '
|
||||
'inference at scale.'},
|
||||
{'description': '', 'name': 'Benchmarks'},
|
||||
{'description': 'Protocol for conversation management operations.',
|
||||
'name': 'Conversations',
|
||||
'x-displayName': 'Conversations'},
|
||||
{'description': '', 'name': 'DatasetIO'},
|
||||
{'description': '', 'name': 'Datasets'},
|
||||
{'description': 'Llama Stack Evaluation API for running evaluations on model and agent candidates.',
|
||||
'name': 'Eval',
|
||||
'x-displayName': 'Evaluations'},
|
||||
{'description': 'This API is used to upload documents that can be used with other Llama Stack APIs.',
|
||||
'name': 'Files',
|
||||
'x-displayName': 'Files'},
|
||||
{'description': 'Llama Stack Inference API for generating completions, chat completions, and embeddings.\n'
|
||||
'\n'
|
||||
'This API provides the raw interface to the underlying models. Three kinds of models are supported:\n'
|
||||
'- LLM models: these models generate "raw" and "chat" (conversational) completions.\n'
|
||||
'- Embedding models: these models generate embeddings to be used for semantic search.\n'
|
||||
'- Rerank models: these models reorder the documents based on their relevance to a query.',
|
||||
'name': 'Inference',
|
||||
'x-displayName': 'Inference'},
|
||||
{'description': 'APIs for inspecting the Llama Stack service, including health status, available API routes with '
|
||||
'methods and implementing providers.',
|
||||
'name': 'Inspect',
|
||||
'x-displayName': 'Inspect'},
|
||||
{'description': '', 'name': 'Models'},
|
||||
{'description': '', 'name': 'PostTraining (Coming Soon)'},
|
||||
{'description': 'Protocol for prompt management operations.', 'name': 'Prompts', 'x-displayName': 'Prompts'},
|
||||
{'description': 'Providers API for inspecting, listing, and modifying providers and their configurations.',
|
||||
'name': 'Providers',
|
||||
'x-displayName': 'Providers'},
|
||||
{'description': 'OpenAI-compatible Moderations API.', 'name': 'Safety', 'x-displayName': 'Safety'},
|
||||
{'description': '', 'name': 'Scoring'},
|
||||
{'description': '', 'name': 'ScoringFunctions'},
|
||||
{'description': '', 'name': 'Shields'},
|
||||
{'description': '', 'name': 'ToolGroups'},
|
||||
{'description': '', 'name': 'ToolRuntime'},
|
||||
{'description': '', 'name': 'VectorIO'}]
|
||||
|
||||
LEGACY_TAG_ORDER = ['Agents',
|
||||
'Batches',
|
||||
'Benchmarks',
|
||||
|
|
@ -408,3 +458,16 @@ LEGACY_TAG_GROUPS = [{'name': 'Operations',
|
|||
'ToolGroups',
|
||||
'ToolRuntime',
|
||||
'VectorIO']}]
|
||||
|
||||
LEGACY_SECURITY = [{'Default': []}]
|
||||
|
||||
LEGACY_OPERATION_KEYS = [
|
||||
'responses',
|
||||
'tags',
|
||||
'summary',
|
||||
'description',
|
||||
'operationId',
|
||||
'parameters',
|
||||
'requestBody',
|
||||
'deprecated',
|
||||
]
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def _get_protocol_method(api: Api, method_name: str) -> Any | None:
|
|||
if _protocol_methods_cache is None:
|
||||
_protocol_methods_cache = {}
|
||||
protocols = api_protocol_map()
|
||||
from llama_stack.apis.tools import SpecialToolGroup, ToolRuntime
|
||||
from llama_stack_api.tools import SpecialToolGroup, ToolRuntime
|
||||
|
||||
toolgroup_protocols = {
|
||||
SpecialToolGroup.rag_tool: ToolRuntime,
|
||||
|
|
|
|||
|
|
@ -12,16 +12,45 @@ import inspect
|
|||
import re
|
||||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from typing import Annotated, Any, get_args, get_origin
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import Field, create_model
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import Api
|
||||
|
||||
from . import app as app_module
|
||||
from .state import _dynamic_models, _extra_body_fields
|
||||
from .state import _extra_body_fields, register_dynamic_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def _to_pascal_case(segment: str) -> str:
|
||||
tokens = re.findall(r"[A-Za-z]+|\d+", segment)
|
||||
return "".join(token.capitalize() for token in tokens if token)
|
||||
|
||||
|
||||
def _compose_request_model_name(webmethod, http_method: str, variant: str | None = None) -> str:
|
||||
segments = []
|
||||
level = (webmethod.level or "").lower()
|
||||
if level and level != "v1":
|
||||
segments.append(_to_pascal_case(str(webmethod.level)))
|
||||
for part in filter(None, webmethod.route.split("/")):
|
||||
lower_part = part.lower()
|
||||
if lower_part in {"v1", "v1alpha", "v1beta"}:
|
||||
continue
|
||||
if part.startswith("{"):
|
||||
param = part[1:].split(":", 1)[0]
|
||||
segments.append(f"By{_to_pascal_case(param)}")
|
||||
else:
|
||||
segments.append(_to_pascal_case(part))
|
||||
if not segments:
|
||||
segments.append("Root")
|
||||
base_name = "".join(segments) + http_method.title() + "Request"
|
||||
if variant:
|
||||
base_name = f"{base_name}{variant}"
|
||||
return base_name
|
||||
|
||||
|
||||
def _extract_path_parameters(path: str) -> list[dict[str, Any]]:
|
||||
|
|
@ -99,21 +128,21 @@ def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_
|
|||
|
||||
|
||||
def _create_dynamic_request_model(
|
||||
webmethod, query_parameters: list[tuple[str, type, Any]], use_any: bool = False, add_uuid: bool = False
|
||||
api: Api,
|
||||
webmethod,
|
||||
http_method: str,
|
||||
query_parameters: list[tuple[str, type, Any]],
|
||||
use_any: bool = False,
|
||||
variant_suffix: str | None = None,
|
||||
) -> type | None:
|
||||
"""Create a dynamic Pydantic model for request body."""
|
||||
try:
|
||||
field_definitions = _build_field_definitions(query_parameters, use_any)
|
||||
if not field_definitions:
|
||||
return None
|
||||
clean_route = webmethod.route.replace("/", "_").replace("{", "").replace("}", "").replace("-", "_")
|
||||
model_name = f"{clean_route}_Request"
|
||||
if add_uuid:
|
||||
model_name = f"{model_name}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
model_name = _compose_request_model_name(webmethod, http_method, variant_suffix)
|
||||
request_model = create_model(model_name, **field_definitions)
|
||||
_dynamic_models.append(request_model)
|
||||
return request_model
|
||||
return register_dynamic_model(model_name, request_model)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
|
@ -190,7 +219,7 @@ def _is_file_or_form_param(param_type: Any) -> bool:
|
|||
|
||||
def _is_extra_body_field(metadata_item: Any) -> bool:
|
||||
"""Check if a metadata item is an ExtraBodyField instance."""
|
||||
from llama_stack.schema_utils import ExtraBodyField
|
||||
from llama_stack_api.schema_utils import ExtraBodyField
|
||||
|
||||
return isinstance(metadata_item, ExtraBodyField)
|
||||
|
||||
|
|
@ -232,7 +261,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
|
|||
streaming_model = inner_type
|
||||
else:
|
||||
# Might be a registered schema - check if it's registered
|
||||
from llama_stack.schema_utils import _registered_schemas
|
||||
from llama_stack_api.schema_utils import _registered_schemas
|
||||
|
||||
if inner_type in _registered_schemas:
|
||||
# We'll need to look this up later, but for now store the type
|
||||
|
|
@ -247,7 +276,7 @@ def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, t
|
|||
|
||||
def _find_models_for_endpoint(
|
||||
webmethod, api: Api, method_name: str, is_post_put: bool = False
|
||||
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None]:
|
||||
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None, str | None]:
|
||||
"""
|
||||
Find appropriate request and response models for an endpoint by analyzing the actual function signature.
|
||||
This uses the protocol function to determine the correct models dynamically.
|
||||
|
|
@ -259,16 +288,18 @@ def _find_models_for_endpoint(
|
|||
is_post_put: Whether this is a POST, PUT, or PATCH request (GET requests should never have request bodies)
|
||||
|
||||
Returns:
|
||||
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model)
|
||||
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name)
|
||||
where query_parameters is a list of (name, type, default_value) tuples
|
||||
and file_form_params is a list of inspect.Parameter objects for File()/Form() params
|
||||
and streaming_response_model is the model for streaming responses (AsyncIterator content)
|
||||
"""
|
||||
route_descriptor = f"{webmethod.method or 'UNKNOWN'} {webmethod.route}"
|
||||
try:
|
||||
# Get the function from the protocol
|
||||
func = app_module._get_protocol_method(api, method_name)
|
||||
if not func:
|
||||
return None, None, [], [], None
|
||||
logger.warning("No protocol method for %s.%s (%s)", api, method_name, route_descriptor)
|
||||
return None, None, [], [], None, None
|
||||
|
||||
# Analyze the function signature
|
||||
sig = inspect.signature(func)
|
||||
|
|
@ -279,6 +310,7 @@ def _find_models_for_endpoint(
|
|||
file_form_params = []
|
||||
path_params = set()
|
||||
extra_body_params = []
|
||||
response_schema_name = None
|
||||
|
||||
# Extract path parameters from the route
|
||||
if webmethod and hasattr(webmethod, "route"):
|
||||
|
|
@ -391,23 +423,49 @@ def _find_models_for_endpoint(
|
|||
elif origin is not None and (origin is types.UnionType or origin is typing.Union):
|
||||
# Handle union types - extract both non-streaming and streaming models
|
||||
response_model, streaming_response_model = _extract_response_models_from_union(return_annotation)
|
||||
else:
|
||||
try:
|
||||
from fastapi import Response as FastAPIResponse
|
||||
except ImportError:
|
||||
FastAPIResponse = None
|
||||
try:
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
except ImportError:
|
||||
StarletteResponse = None
|
||||
|
||||
return request_model, response_model, query_parameters, file_form_params, streaming_response_model
|
||||
response_types = tuple(t for t in (FastAPIResponse, StarletteResponse) if t is not None)
|
||||
if response_types and any(return_annotation is t for t in response_types):
|
||||
response_schema_name = "Response"
|
||||
|
||||
except Exception:
|
||||
# If we can't analyze the function signature, return None
|
||||
return None, None, [], [], None
|
||||
return request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to analyze endpoint %s.%s (%s): %s", api, method_name, route_descriptor, exc, exc_info=True
|
||||
)
|
||||
return None, None, [], [], None, None
|
||||
|
||||
|
||||
def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
||||
"""Create a FastAPI endpoint from a discovered route and webmethod."""
|
||||
path = route.path
|
||||
methods = route.methods
|
||||
raw_methods = route.methods or set()
|
||||
method_list = sorted({method.upper() for method in raw_methods if method and method.upper() != "HEAD"})
|
||||
if not method_list:
|
||||
method_list = ["GET"]
|
||||
primary_method = method_list[0]
|
||||
name = route.name
|
||||
fastapi_path = path.replace("{", "{").replace("}", "}")
|
||||
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
|
||||
is_post_put = any(method in ["POST", "PUT", "PATCH"] for method in method_list)
|
||||
|
||||
request_model, response_model, query_parameters, file_form_params, streaming_response_model = (
|
||||
(
|
||||
request_model,
|
||||
response_model,
|
||||
query_parameters,
|
||||
file_form_params,
|
||||
streaming_response_model,
|
||||
response_schema_name,
|
||||
) = (
|
||||
_find_models_for_endpoint(webmethod, api, name, is_post_put)
|
||||
)
|
||||
operation_description = _extract_operation_description_from_docstring(api, name)
|
||||
|
|
@ -417,7 +475,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
func = app_module._get_protocol_method(api, name)
|
||||
extra_body_params = getattr(func, "_extra_body_params", []) if func else []
|
||||
if extra_body_params:
|
||||
for method in methods:
|
||||
for method in method_list:
|
||||
key = (fastapi_path, method.upper())
|
||||
_extra_body_fields[key] = extra_body_params
|
||||
|
||||
|
|
@ -447,12 +505,11 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description)
|
||||
elif response_model and query_parameters:
|
||||
if is_post_put:
|
||||
# Try creating request model with type preservation, fallback to Any, then minimal
|
||||
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=False)
|
||||
request_model = _create_dynamic_request_model(api, webmethod, primary_method, query_parameters, use_any=False)
|
||||
if not request_model:
|
||||
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True)
|
||||
if not request_model:
|
||||
request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True, add_uuid=True)
|
||||
request_model = _create_dynamic_request_model(
|
||||
api, webmethod, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
|
||||
)
|
||||
|
||||
if request_model:
|
||||
endpoint_func = _create_endpoint_with_request_model(
|
||||
|
|
@ -532,16 +589,18 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
endpoint_func = no_params_endpoint
|
||||
|
||||
# Build response content with both application/json and text/event-stream if streaming
|
||||
response_content = {}
|
||||
response_content: dict[str, Any] = {}
|
||||
if response_model:
|
||||
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}}
|
||||
elif response_schema_name:
|
||||
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_schema_name}"}}
|
||||
if streaming_response_model:
|
||||
# Get the schema name for the streaming model
|
||||
# It might be a registered schema or a Pydantic model
|
||||
streaming_schema_name = None
|
||||
# Check if it's a registered schema first (before checking __name__)
|
||||
# because registered schemas might be Annotated types
|
||||
from llama_stack.schema_utils import _registered_schemas
|
||||
from llama_stack_api.schema_utils import _registered_schemas
|
||||
|
||||
if streaming_response_model in _registered_schemas:
|
||||
streaming_schema_name = _registered_schemas[streaming_response_model]["name"]
|
||||
|
|
@ -554,9 +613,6 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
}
|
||||
|
||||
# If no content types, use empty schema
|
||||
if not response_content:
|
||||
response_content["application/json"] = {"schema": {}}
|
||||
|
||||
# Add the endpoint to the FastAPI app
|
||||
is_deprecated = webmethod.deprecated or False
|
||||
route_kwargs = {
|
||||
|
|
@ -564,16 +620,16 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
"tags": [_get_tag_from_api(api)],
|
||||
"deprecated": is_deprecated,
|
||||
"responses": {
|
||||
200: {
|
||||
"description": response_description,
|
||||
"content": response_content,
|
||||
},
|
||||
400: {"$ref": "#/components/responses/BadRequest400"},
|
||||
429: {"$ref": "#/components/responses/TooManyRequests429"},
|
||||
500: {"$ref": "#/components/responses/InternalServerError500"},
|
||||
"default": {"$ref": "#/components/responses/DefaultError"},
|
||||
},
|
||||
}
|
||||
success_response: dict[str, Any] = {"description": response_description}
|
||||
if response_content:
|
||||
success_response["content"] = response_content
|
||||
route_kwargs["responses"][200] = success_response
|
||||
|
||||
# FastAPI needs response_model parameter to properly generate OpenAPI spec
|
||||
# Use the non-streaming response model if available
|
||||
|
|
@ -581,6 +637,6 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
route_kwargs["response_model"] = response_model
|
||||
|
||||
method_map = {"GET": app.get, "POST": app.post, "PUT": app.put, "DELETE": app.delete, "PATCH": app.patch}
|
||||
for method in methods:
|
||||
if handler := method_map.get(method.upper()):
|
||||
for method in method_list:
|
||||
if handler := method_map.get(method):
|
||||
handler(fastapi_path, **route_kwargs)(endpoint_func)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from typing import Any
|
|||
import yaml
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from . import app, schema_collection, schema_filtering, schema_transforms
|
||||
from . import app, schema_collection, schema_filtering, schema_transforms, state
|
||||
|
||||
|
||||
def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
|
||||
|
|
@ -29,6 +29,7 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
|
|||
Returns:
|
||||
The generated OpenAPI specification as a dictionary
|
||||
"""
|
||||
state.reset_generator_state()
|
||||
# Create the FastAPI app
|
||||
fastapi_app = app.create_llama_stack_app()
|
||||
|
||||
|
|
@ -143,7 +144,7 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
|
|||
schema_transforms._fix_schema_issues(schema)
|
||||
schema_transforms._apply_legacy_sorting(schema)
|
||||
|
||||
print("\n🔍 Validating generated schemas...")
|
||||
print("\nValidating generated schemas...")
|
||||
failed_schemas = [
|
||||
name for schema, name in schemas_to_validate if not schema_transforms.validate_openapi_schema(schema, name)
|
||||
]
|
||||
|
|
@ -186,20 +187,20 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
|
|||
# Write the modified YAML back
|
||||
schema_transforms._write_yaml_file(yaml_path, yaml_data)
|
||||
|
||||
print(f"✅ Generated YAML (stable): {yaml_path}")
|
||||
print(f"Generated YAML (stable): {yaml_path}")
|
||||
|
||||
experimental_yaml_path = output_path / "experimental-llama-stack-spec.yaml"
|
||||
schema_transforms._write_yaml_file(experimental_yaml_path, experimental_schema)
|
||||
print(f"✅ Generated YAML (experimental): {experimental_yaml_path}")
|
||||
print(f"Generated YAML (experimental): {experimental_yaml_path}")
|
||||
|
||||
deprecated_yaml_path = output_path / "deprecated-llama-stack-spec.yaml"
|
||||
schema_transforms._write_yaml_file(deprecated_yaml_path, deprecated_schema)
|
||||
print(f"✅ Generated YAML (deprecated): {deprecated_yaml_path}")
|
||||
print(f"Generated YAML (deprecated): {deprecated_yaml_path}")
|
||||
|
||||
# Generate combined (stainless) spec
|
||||
stainless_yaml_path = output_path / "stainless-llama-stack-spec.yaml"
|
||||
schema_transforms._write_yaml_file(stainless_yaml_path, combined_schema)
|
||||
print(f"✅ Generated YAML (stainless/combined): {stainless_yaml_path}")
|
||||
print(f"Generated YAML (stainless/combined): {stainless_yaml_path}")
|
||||
|
||||
return stable_schema
|
||||
|
||||
|
|
@ -213,25 +214,25 @@ def main():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🚀 Generating OpenAPI specification using FastAPI...")
|
||||
print(f"📁 Output directory: {args.output_dir}")
|
||||
print("Generating OpenAPI specification using FastAPI...")
|
||||
print(f"Output directory: {args.output_dir}")
|
||||
|
||||
try:
|
||||
openapi_schema = generate_openapi_spec(output_dir=args.output_dir)
|
||||
|
||||
print("\n✅ OpenAPI specification generated successfully!")
|
||||
print(f"📊 Schemas: {len(openapi_schema.get('components', {}).get('schemas', {}))}")
|
||||
print(f"🛣️ Paths: {len(openapi_schema.get('paths', {}))}")
|
||||
print("\nOpenAPI specification generated successfully!")
|
||||
print(f"Schemas: {len(openapi_schema.get('components', {}).get('schemas', {}))}")
|
||||
print(f"Paths: {len(openapi_schema.get('paths', {}))}")
|
||||
operation_count = sum(
|
||||
1
|
||||
for path_info in openapi_schema.get("paths", {}).values()
|
||||
for method in ["get", "post", "put", "delete", "patch"]
|
||||
if method in path_info
|
||||
)
|
||||
print(f"🔧 Operations: {operation_count}")
|
||||
print(f"Operations: {operation_count}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error generating OpenAPI specification: {e}")
|
||||
print(f"Error generating OpenAPI specification: {e}")
|
||||
raise
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ from ._legacy_order import (
|
|||
LEGACY_PATH_ORDER,
|
||||
LEGACY_RESPONSE_ORDER,
|
||||
LEGACY_SCHEMA_ORDER,
|
||||
LEGACY_OPERATION_KEYS,
|
||||
LEGACY_SECURITY,
|
||||
LEGACY_TAGS,
|
||||
LEGACY_TAG_GROUPS,
|
||||
LEGACY_TAG_ORDER,
|
||||
)
|
||||
|
|
@ -121,7 +124,7 @@ def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
|||
openapi_schema["components"]["responses"] = {}
|
||||
|
||||
try:
|
||||
from llama_stack.apis.datatypes import Error
|
||||
from llama_stack_api.datatypes import Error
|
||||
|
||||
schema_collection._ensure_components_schemas(openapi_schema)
|
||||
if "Error" not in openapi_schema["components"]["schemas"]:
|
||||
|
|
@ -129,6 +132,10 @@ def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
schema_collection._ensure_components_schemas(openapi_schema)
|
||||
if "Response" not in openapi_schema["components"]["schemas"]:
|
||||
openapi_schema["components"]["schemas"]["Response"] = {"title": "Response", "type": "object"}
|
||||
|
||||
# Define standard HTTP error responses
|
||||
error_responses = {
|
||||
400: {
|
||||
|
|
@ -848,6 +855,20 @@ def _apply_legacy_sorting(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
|||
paths = openapi_schema.get("paths")
|
||||
if isinstance(paths, dict):
|
||||
openapi_schema["paths"] = order_mapping(paths, LEGACY_PATH_ORDER)
|
||||
for path, path_item in openapi_schema["paths"].items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
ordered_path_item = OrderedDict()
|
||||
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
if method in path_item:
|
||||
ordered_path_item[method] = order_mapping(path_item[method], LEGACY_OPERATION_KEYS)
|
||||
for key, value in path_item.items():
|
||||
if key not in ordered_path_item:
|
||||
if isinstance(value, dict) and key.lower() in {"get", "post", "put", "delete", "patch", "head", "options"}:
|
||||
ordered_path_item[key] = order_mapping(value, LEGACY_OPERATION_KEYS)
|
||||
else:
|
||||
ordered_path_item[key] = value
|
||||
openapi_schema["paths"][path] = ordered_path_item
|
||||
|
||||
components = openapi_schema.setdefault("components", {})
|
||||
schemas = components.get("schemas")
|
||||
|
|
@ -857,30 +878,14 @@ def _apply_legacy_sorting(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
|||
if isinstance(responses, dict):
|
||||
components["responses"] = order_mapping(responses, LEGACY_RESPONSE_ORDER)
|
||||
|
||||
tags = openapi_schema.get("tags")
|
||||
if isinstance(tags, list):
|
||||
tag_priority = {name: idx for idx, name in enumerate(LEGACY_TAG_ORDER)}
|
||||
if LEGACY_TAGS:
|
||||
openapi_schema["tags"] = LEGACY_TAGS
|
||||
|
||||
def tag_sort(tag_obj: dict[str, Any]) -> tuple[int, int | str]:
|
||||
name = tag_obj.get("name", "")
|
||||
if name in tag_priority:
|
||||
return (0, tag_priority[name])
|
||||
return (1, name)
|
||||
if LEGACY_TAG_GROUPS:
|
||||
openapi_schema["x-tagGroups"] = LEGACY_TAG_GROUPS
|
||||
|
||||
openapi_schema["tags"] = sorted(tags, key=tag_sort)
|
||||
|
||||
tag_groups = openapi_schema.get("x-tagGroups")
|
||||
if isinstance(tag_groups, list) and LEGACY_TAG_GROUPS:
|
||||
legacy_tags = LEGACY_TAG_GROUPS[0].get("tags", [])
|
||||
tag_priority = {name: idx for idx, name in enumerate(legacy_tags)}
|
||||
for group in tag_groups:
|
||||
group_tags = group.get("tags")
|
||||
if isinstance(group_tags, list):
|
||||
group["tags"] = sorted(
|
||||
group_tags,
|
||||
key=lambda name: (0, tag_priority[name]) if name in tag_priority else (1, name),
|
||||
)
|
||||
openapi_schema["x-tagGroups"] = tag_groups
|
||||
if LEGACY_SECURITY:
|
||||
openapi_schema["security"] = LEGACY_SECURITY
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
|
@ -914,12 +919,11 @@ def validate_openapi_schema(schema: dict[str, Any], schema_name: str = "OpenAPI
|
|||
"""
|
||||
try:
|
||||
validate_spec(schema)
|
||||
print(f"✅ {schema_name} is valid")
|
||||
print(f"{schema_name} is valid")
|
||||
return True
|
||||
except OpenAPISpecValidatorError as e:
|
||||
print(f"❌ {schema_name} validation failed:")
|
||||
print(f" {e}")
|
||||
print(f"{schema_name} validation failed: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ {schema_name} validation error: {e}")
|
||||
print(f"{schema_name} validation error: {e}")
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from llama_stack_api import Api
|
|||
|
||||
# Global list to store dynamic models created during endpoint generation
|
||||
_dynamic_models: list[Any] = []
|
||||
_dynamic_model_registry: dict[str, type] = {}
|
||||
|
||||
# Cache for protocol methods to avoid repeated lookups
|
||||
_protocol_methods_cache: dict[Api, dict[str, Any]] | None = None
|
||||
|
|
@ -21,3 +22,20 @@ _protocol_methods_cache: dict[Api, dict[str, Any]] | None = None
|
|||
# Global dict to store extra body field information by endpoint
|
||||
# Key: (path, method) tuple, Value: list of (param_name, param_type, description) tuples
|
||||
_extra_body_fields: dict[tuple[str, str], list[tuple[str, type, str | None]]] = {}
|
||||
|
||||
|
||||
def register_dynamic_model(name: str, model: type) -> type:
|
||||
"""Register and deduplicate dynamically generated request models."""
|
||||
existing = _dynamic_model_registry.get(name)
|
||||
if existing is not None:
|
||||
return existing
|
||||
_dynamic_model_registry[name] = model
|
||||
_dynamic_models.append(model)
|
||||
return model
|
||||
|
||||
|
||||
def reset_generator_state() -> None:
|
||||
"""Clear per-run caches so repeated generations stay deterministic."""
|
||||
_dynamic_models.clear()
|
||||
_dynamic_model_registry.clear()
|
||||
_extra_body_fields.clear()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue