restore some types we had erased

This commit is contained in:
Ashwin Bharambe 2025-11-14 15:03:36 -08:00
parent 0c9ffff1b8
commit e3e8272bbe
5 changed files with 13830 additions and 6 deletions

View file

@ -10,6 +10,7 @@ Schema filtering and version filtering for OpenAPI generation.
from typing import Any
from llama_stack_api.schema_utils import iter_json_schema_types, iter_registered_schema_types
from llama_stack_api.version import (
LLAMA_STACK_API_V1,
LLAMA_STACK_API_V1ALPHA,
@ -17,6 +18,23 @@ from llama_stack_api.version import (
)
def _get_all_json_schema_type_names() -> set[str]:
"""Collect schema names from @json_schema_type-decorated models."""
schema_names = set()
for model in iter_json_schema_types():
schema_name = getattr(model, "_llama_stack_schema_name", None) or getattr(model, "__name__", None)
if schema_name:
schema_names.add(schema_name)
return schema_names
def _get_explicit_schema_names(openapi_schema: dict[str, Any]) -> set[str]:
"""Schema names to keep even if not referenced by a path."""
registered_schema_names = {info.name for info in iter_registered_schema_types()}
json_schema_type_names = _get_all_json_schema_type_names()
return registered_schema_names | json_schema_type_names
def _find_schema_refs_in_object(obj: Any) -> set[str]:
"""
Recursively find all schema references ($ref) in an object.
@ -37,12 +55,21 @@ def _find_schema_refs_in_object(obj: Any) -> set[str]:
return refs
def _add_transitive_references(referenced_schemas: set[str], all_schemas: dict[str, Any]) -> set[str]:
def _add_transitive_references(
referenced_schemas: set[str], all_schemas: dict[str, Any], initial_schemas: set[str] | None = None
) -> set[str]:
"""Add transitive references for given schemas."""
additional_schemas = set()
for schema_name in referenced_schemas:
if schema_name in all_schemas:
additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
if initial_schemas:
referenced_schemas.update(initial_schemas)
additional_schemas = set()
for schema_name in initial_schemas:
if schema_name in all_schemas:
additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
else:
additional_schemas = set()
for schema_name in referenced_schemas:
if schema_name in all_schemas:
additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
while additional_schemas:
new_schemas = additional_schemas - referenced_schemas
@ -113,7 +140,8 @@ def _filter_schemas_by_references(
referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
all_schemas = openapi_schema.get("components", {}).get("schemas", {})
referenced_schemas = _add_transitive_references(referenced_schemas, all_schemas)
explicit_names = _get_explicit_schema_names(openapi_schema)
referenced_schemas = _add_transitive_references(referenced_schemas, all_schemas, explicit_names)
filtered_schemas = {
name: schema for name, schema in filtered_schema["components"]["schemas"].items() if name in referenced_schemas