# Source: llama-stack-mirror/tests/unit/server/test_schema_registry.py
# (48 lines, 1.5 KiB, Python)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack_api import Conversation, SamplingStrategy
from llama_stack_api.schema_utils import (
clear_dynamic_schema_types,
get_registered_schema_info,
iter_dynamic_schema_types,
iter_json_schema_types,
iter_registered_schema_types,
register_dynamic_schema_type,
)
def test_json_schema_registry_contains_known_model() -> None:
    """The JSON-schema type registry must expose the ``Conversation`` model."""
    json_schema_types = iter_json_schema_types()
    assert Conversation in json_schema_types
def test_registered_schema_registry_contains_sampling_strategy() -> None:
    """``SamplingStrategy`` is listed by name and resolvable to its schema info.

    Checks both directions of the registered-schema registry: enumerating all
    registered entries yields the name, and looking up the model class returns
    an entry carrying that same name.
    """
    seen_names: set[str] = set()
    for entry in iter_registered_schema_types():
        seen_names.add(entry.name)
    assert "SamplingStrategy" in seen_names

    sampling_info = get_registered_schema_info(SamplingStrategy)
    assert sampling_info is not None
    assert sampling_info.name == "SamplingStrategy"
def test_dynamic_schema_registration_round_trip() -> None:
    """Registering a dynamic schema type makes it visible; clearing removes it.

    The dynamic registry is shared with the rest of the test run (which is why
    this test snapshots and restores it). Whatever was registered beforehand is
    restored exactly when the test ends, whether or not the assertions pass.
    """
    existing_models = tuple(iter_dynamic_schema_types())

    class TemporaryModel(BaseModel):
        foo: str

    try:
        clear_dynamic_schema_types()
        register_dynamic_schema_type(TemporaryModel)
        assert TemporaryModel in iter_dynamic_schema_types(), (
            "a freshly registered model should be listed in the dynamic registry"
        )

        clear_dynamic_schema_types()
        assert TemporaryModel not in iter_dynamic_schema_types(), (
            "clear_dynamic_schema_types() should remove every registered model"
        )
    finally:
        # Reset before restoring. If an assertion above failed, TemporaryModel is
        # still registered, and re-registering the snapshot on top of it would
        # leak the temporary model into later tests.
        clear_dynamic_schema_types()
        for model in existing_models:
            register_dynamic_schema_type(model)