From 52503490d8761e601a50bef6913c916ec1fcd39e Mon Sep 17 00:00:00 2001 From: Ignas Baranauskas Date: Tue, 19 Aug 2025 11:39:00 +0100 Subject: [PATCH] test: add tests for model not persistant models --- .../routers/test_routing_tables.py | 343 +++++++++++++----- 1 file changed, 250 insertions(+), 93 deletions(-) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index bbfea3f46..13fcc93f1 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -17,7 +17,7 @@ from llama_stack.apis.models import Model, ModelType from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter from llama_stack.apis.vector_dbs import VectorDB -from llama_stack.core.datatypes import RegistryEntrySource +from llama_stack.core.datatypes import RegistryEntrySource, StackRunConfig from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable from llama_stack.core.routing_tables.models import ModelsRoutingTable @@ -534,114 +534,271 @@ async def test_models_source_tracking_provider(cached_disk_dist_registry): await table.shutdown() -async def test_models_source_interaction_preserves_default(cached_disk_dist_registry): - """Test that provider refresh preserves user-registered models with default source.""" - table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) +async def test_models_dynamic_from_config_generation(cached_disk_dist_registry): + """Test that from_config models are generated dynamically from run_config.""" + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) await table.initialize() - # First register a user model with same provider_resource_id as provider will later provide - await table.register_model( - model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider" + # Test that no from_config models are registered when no run_config + all_models = await table.get_all_with_type("model") + from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config] + assert len(from_config_models) == 0 + + # Create a run config with from_config models + run_config = StackRunConfig( + image_name="test", + providers={}, + models=[ + { + "model_id": "from_config_model_1", + "provider_id": "test_provider", + "model_type": "llm", + "provider_model_id": "gpt-3.5-turbo", + }, + { + "model_id": "from_config_model_2", + "provider_id": "test_provider", + "model_type": "llm", + "provider_model_id": "gpt-4", + }, + ], ) - # Verify user model is registered with default source - models = await table.list_models() - assert len(models.data) == 1 - user_model = models.data[0] - assert user_model.source == RegistryEntrySource.via_register_api - assert user_model.identifier == "my-custom-alias" - assert user_model.provider_resource_id == "provider-model-1" + # Set the run config + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() - # Now simulate provider refresh - provider_models = [ - Model( - identifier="provider-model-1", - provider_resource_id="provider-model-1", - provider_id="test_provider", - metadata={}, - model_type=ModelType.llm, - ), - Model( - identifier="different-model", - provider_resource_id="different-model", - provider_id="test_provider", - metadata={}, - model_type=ModelType.llm, - ), - ] - await table.update_registered_models("test_provider", provider_models) + # Test that from_config models are registered in the registry + all_models = await table.get_all_with_type("model") + from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config] + assert len(from_config_models) == 2 - # Verify user model with alias is preserved, but provider added new model - models = await table.list_models() - assert len(models.data) == 2 + model_identifiers = {m.identifier for m in from_config_models} + assert "from_config_model_1" in model_identifiers + assert "from_config_model_2" in model_identifiers - # Find the user model and provider model - user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) - provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None) - - assert user_model is not None - assert user_model.source == RegistryEntrySource.via_register_api - assert user_model.provider_resource_id == "provider-model-1" - - assert provider_model is not None - assert provider_model.source == RegistryEntrySource.listed_from_provider - assert provider_model.provider_resource_id == "different-model" + # Test that from_config models have correct attributes + model_1 = next(m for m in from_config_models if m.identifier == "from_config_model_1") + assert model_1.provider_id == "test_provider" + assert model_1.provider_resource_id == "gpt-3.5-turbo" + assert model_1.model_type == ModelType.llm + assert model_1.source == RegistryEntrySource.from_config # Cleanup await table.shutdown() -async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry): - """Test that provider refresh removes old provider models but keeps default ones.""" - table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) +async def test_models_dynamic_from_config_lookup(cached_disk_dist_registry): + """Test that from_config models can be looked up individually.""" + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) await table.initialize() - # Register a user model - await table.register_model(model_id="user-model", provider_id="test_provider") + # Create a run config with from_config models + run_config = StackRunConfig( + image_name="test", + providers={}, + models=[ + { + "model_id": "lookup_test_model", + "provider_id": "test_provider", + "model_type": "llm", + "provider_model_id": "gpt-3.5-turbo", + } + ], + ) - # Add some provider models - provider_models_v1 = [ - Model( - identifier="provider-model-old", - provider_resource_id="provider-model-old", - provider_id="test_provider", - metadata={}, - model_type=ModelType.llm, - ), - ] - await table.update_registered_models("test_provider", provider_models_v1) + # Set the run config + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() - # Verify we have both user and provider models - models = await table.list_models() - assert len(models.data) == 2 - - # Now update with new provider models (should remove old provider models) - provider_models_v2 = [ - Model( - identifier="provider-model-new", - provider_resource_id="provider-model-new", - provider_id="test_provider", - metadata={}, - model_type=ModelType.llm, - ), - ] - await table.update_registered_models("test_provider", provider_models_v2) - - # Should have user model + new provider model, old provider model gone - models = await table.list_models() - assert len(models.data) == 2 - - identifiers = {m.identifier for m in models.data} - assert "test_provider/user-model" in identifiers # User model preserved - assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier) - assert "test_provider/provider-model-old" not in identifiers # Old provider model removed - - # Verify sources are correct - user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None) - provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None) - - assert user_model.source == RegistryEntrySource.via_register_api - assert provider_model.source == RegistryEntrySource.listed_from_provider + # Test that we can get the from_config model individually + model = await table.get_model("lookup_test_model") + assert model is not None + assert model.identifier == "lookup_test_model" + assert model.provider_id == "test_provider" + assert model.provider_resource_id == "gpt-3.5-turbo" + assert model.source == RegistryEntrySource.from_config + + # Cleanup + await table.shutdown() + + +async def test_models_dynamic_from_config_mixed_with_persistent(cached_disk_dist_registry): + """Test that from_config models work alongside persistent models.""" + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) + await table.initialize() + + # Create a run config with from_config models + run_config = StackRunConfig( + image_name="test", + providers={}, + models=[ + { + "model_id": "from_config_model", + "provider_id": "test_provider", + "model_type": "llm", + "provider_model_id": "gpt-3.5-turbo", + } + ], + ) + + # Set the run config + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() + + # Test that from_config models are included + models = await table.list_models() + from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config] + + assert len(from_config_models) == 1 + assert from_config_models[0].identifier == "from_config_model" + + # Test that we can get the from_config model individually + from_config_model = await table.get_model("from_config_model") + assert from_config_model is not None + assert from_config_model.source == RegistryEntrySource.from_config + + # Cleanup + await table.shutdown() + + +async def test_models_dynamic_from_config_disabled_providers(cached_disk_dist_registry): + """Test that from_config models with disabled providers are skipped.""" + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) + await table.initialize() + + # Create a run config with disabled provider models + run_config = StackRunConfig( + image_name="test", + providers={}, + models=[ + { + "model_id": "enabled_model", + "provider_id": "test_provider", + "model_type": "llm", + "provider_model_id": "gpt-3.5-turbo", + }, + { + "model_id": "disabled_model", + "provider_id": "__disabled__", + "model_type": "llm", + "provider_model_id": "gpt-4", + }, + ], + ) + + # Set the run config + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() + + # Test that only enabled models are included + all_models = await table.get_all_with_type("model") + from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config] + assert len(from_config_models) == 1 + assert from_config_models[0].identifier == "enabled_model" + + # Cleanup + await table.shutdown() + + +async def test_models_dynamic_from_config_no_run_config(cached_disk_dist_registry): + """Test that from_config models work when no run_config is set.""" + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) + await table.initialize() + + # Test that list_models works without run_config + models = await table.list_models() + from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config] + assert len(from_config_models) == 0 # No from_config models when no run_config + + # Cleanup + await table.shutdown() + + +async def test_models_filter_persistent_models_from_removed_providers(cached_disk_dist_registry): + """Test that models from removed providers are filtered out from persistent models.""" + from llama_stack.apis.models import ModelType + from llama_stack.core.datatypes import ModelWithOwner, Provider, RegistryEntrySource, StackRunConfig + from llama_stack.core.routing_tables.models import ModelsRoutingTable + + # Create a routing table + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) + await table.initialize() + + # Create some mock persistent models + model1 = ModelWithOwner( + identifier="test_provider_1/model1", + provider_resource_id="model1", + provider_id="test_provider_1", + metadata={}, + model_type=ModelType.llm, + source=RegistryEntrySource.listed_from_provider, + ) + model2 = ModelWithOwner( + identifier="test_provider_2/model2", + provider_resource_id="model2", + provider_id="test_provider_2", + metadata={}, + model_type=ModelType.llm, + source=RegistryEntrySource.listed_from_provider, + ) + user_model = ModelWithOwner( + identifier="user_model", + provider_resource_id="user_model", + provider_id="test_provider_1", + metadata={}, + model_type=ModelType.llm, + source=RegistryEntrySource.via_register_api, + ) + + # Create a run config that only includes test_provider_1 (test_provider_2 is removed) + run_config = StackRunConfig( + image_name="test", + providers={ + "inference": [ + Provider(provider_id="test_provider_1", provider_type="openai", config={"api_key": "test_key"}), + # test_provider_2 is removed from run.yaml + ] + }, + models=[], + ) + + # Set the run config + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() + + # Test the cleanup logic directly + # First, manually add models to the registry to simulate existing models + await table.dist_registry.register(model1) + await table.dist_registry.register(model2) + await table.dist_registry.register(user_model) + + # Now set the run config which should trigger cleanup + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() + + # Get the list of models after cleanup + response = await table.list_models() + model_identifiers = {m.identifier for m in response.data} + + # Should have user_model (user-registered) and model1 (from enabled provider), but not model2 (from disabled provider) + # model1 should be kept because test_provider_1 is in the run config (enabled) + # model2 should be removed because test_provider_2 is not in the run config (disabled) + # user_model should be kept because it's user-registered + assert "user_model" in model_identifiers + assert "test_provider_1/model1" in model_identifiers + assert "test_provider_2/model2" not in model_identifiers + + # Test that user-registered models are always kept regardless of provider status + user_model_found = next((m for m in response.data if m.identifier == "user_model"), None) + assert user_model_found is not None + assert user_model_found.source == RegistryEntrySource.via_register_api # Cleanup await table.shutdown()