restore some types we had erased

This commit is contained in:
Ashwin Bharambe 2025-11-14 15:03:36 -08:00
parent 0c9ffff1b8
commit e3e8272bbe
5 changed files with 13830 additions and 6 deletions

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -10,6 +10,7 @@ Schema filtering and version filtering for OpenAPI generation.
from typing import Any from typing import Any
from llama_stack_api.schema_utils import iter_json_schema_types, iter_registered_schema_types
from llama_stack_api.version import ( from llama_stack_api.version import (
LLAMA_STACK_API_V1, LLAMA_STACK_API_V1,
LLAMA_STACK_API_V1ALPHA, LLAMA_STACK_API_V1ALPHA,
@ -17,6 +18,23 @@ from llama_stack_api.version import (
) )
def _get_all_json_schema_type_names() -> set[str]:
"""Collect schema names from @json_schema_type-decorated models."""
schema_names = set()
for model in iter_json_schema_types():
schema_name = getattr(model, "_llama_stack_schema_name", None) or getattr(model, "__name__", None)
if schema_name:
schema_names.add(schema_name)
return schema_names
def _get_explicit_schema_names(openapi_schema: dict[str, Any]) -> set[str]:
"""Schema names to keep even if not referenced by a path."""
registered_schema_names = {info.name for info in iter_registered_schema_types()}
json_schema_type_names = _get_all_json_schema_type_names()
return registered_schema_names | json_schema_type_names
def _find_schema_refs_in_object(obj: Any) -> set[str]: def _find_schema_refs_in_object(obj: Any) -> set[str]:
""" """
Recursively find all schema references ($ref) in an object. Recursively find all schema references ($ref) in an object.
@ -37,8 +55,17 @@ def _find_schema_refs_in_object(obj: Any) -> set[str]:
return refs return refs
def _add_transitive_references(referenced_schemas: set[str], all_schemas: dict[str, Any]) -> set[str]: def _add_transitive_references(
referenced_schemas: set[str], all_schemas: dict[str, Any], initial_schemas: set[str] | None = None
) -> set[str]:
"""Add transitive references for given schemas.""" """Add transitive references for given schemas."""
if initial_schemas:
referenced_schemas.update(initial_schemas)
additional_schemas = set()
for schema_name in initial_schemas:
if schema_name in all_schemas:
additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
else:
additional_schemas = set() additional_schemas = set()
for schema_name in referenced_schemas: for schema_name in referenced_schemas:
if schema_name in all_schemas: if schema_name in all_schemas:
@ -113,7 +140,8 @@ def _filter_schemas_by_references(
referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema) referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
all_schemas = openapi_schema.get("components", {}).get("schemas", {}) all_schemas = openapi_schema.get("components", {}).get("schemas", {})
referenced_schemas = _add_transitive_references(referenced_schemas, all_schemas) explicit_names = _get_explicit_schema_names(openapi_schema)
referenced_schemas = _add_transitive_references(referenced_schemas, all_schemas, explicit_names)
filtered_schemas = { filtered_schemas = {
name: schema for name, schema in filtered_schema["components"]["schemas"].items() if name in referenced_schemas name: schema for name, schema in filtered_schema["components"]["schemas"].items() if name in referenced_schemas