# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from pydantic import BaseModel from llama_stack_api import Conversation, SamplingStrategy from llama_stack_api.schema_utils import ( clear_dynamic_schema_types, get_registered_schema_info, iter_dynamic_schema_types, iter_json_schema_types, iter_registered_schema_types, register_dynamic_schema_type, ) def test_json_schema_registry_contains_known_model() -> None: assert Conversation in iter_json_schema_types() def test_registered_schema_registry_contains_sampling_strategy() -> None: registered_names = {info.name for info in iter_registered_schema_types()} assert "SamplingStrategy" in registered_names schema_info = get_registered_schema_info(SamplingStrategy) assert schema_info is not None assert schema_info.name == "SamplingStrategy" def test_dynamic_schema_registration_round_trip() -> None: existing_models = tuple(iter_dynamic_schema_types()) clear_dynamic_schema_types() try: class TemporaryModel(BaseModel): foo: str register_dynamic_schema_type(TemporaryModel) assert TemporaryModel in iter_dynamic_schema_types() clear_dynamic_schema_types() assert TemporaryModel not in iter_dynamic_schema_types() finally: for model in existing_models: register_dynamic_schema_type(model)