# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. # Unit tests for the routing tables from unittest.mock import AsyncMock import pytest from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api from llama_stack.apis.models import Model, ModelType from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter from llama_stack.apis.vector_dbs import VectorDB from llama_stack.core.datatypes import RegistryEntrySource, StackRunConfig from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable from llama_stack.core.routing_tables.models import ModelsRoutingTable from llama_stack.core.routing_tables.scoring_functions import ScoringFunctionsRoutingTable from llama_stack.core.routing_tables.shields import ShieldsRoutingTable from llama_stack.core.routing_tables.toolgroups import ToolGroupsRoutingTable from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable class Impl: def __init__(self, api: Api): self.api = api @property def __provider_spec__(self): _provider_spec = AsyncMock() _provider_spec.api = self.api return _provider_spec class InferenceImpl(Impl): def __init__(self): super().__init__(Api.inference) async def register_model(self, model: Model): return model async def unregister_model(self, model_id: str): return model_id async def should_refresh_models(self): return False async def list_models(self): return [ Model( identifier="provider-model-1", provider_resource_id="provider-model-1", provider_id="test_provider", metadata={}, model_type=ModelType.llm, ), Model( identifier="provider-model-2", provider_resource_id="provider-model-2", provider_id="test_provider", metadata={"embedding_dimension": 512}, model_type=ModelType.embedding, ), ] async def shutdown(self): pass class SafetyImpl(Impl): def __init__(self): super().__init__(Api.safety) async def register_shield(self, shield: Shield): return shield async def unregister_shield(self, shield_id: str): return shield_id class DatasetsImpl(Impl): def __init__(self): super().__init__(Api.datasetio) async def register_dataset(self, dataset: Dataset): return dataset async def unregister_dataset(self, dataset_id: str): return dataset_id class ScoringFunctionsImpl(Impl): def __init__(self): super().__init__(Api.scoring) async def list_scoring_functions(self): return [] async def register_scoring_function(self, scoring_fn): return scoring_fn async def unregister_scoring_function(self, scoring_fn_id: str): return scoring_fn_id class BenchmarksImpl(Impl): def __init__(self): super().__init__(Api.eval) async def register_benchmark(self, benchmark): return benchmark async def unregister_benchmark(self, benchmark_id: str): return benchmark_id class ToolGroupsImpl(Impl): def __init__(self): super().__init__(Api.tool_runtime) async def register_toolgroup(self, toolgroup: ToolGroup): return toolgroup async def unregister_toolgroup(self, toolgroup_id: str): return toolgroup_id async def list_runtime_tools(self, toolgroup_id, mcp_endpoint): return ListToolDefsResponse( data=[ ToolDef( name="test-tool", description="Test tool", parameters=[ToolParameter(name="test-param", description="Test param", parameter_type="string")], ) ] ) class VectorDBImpl(Impl): def __init__(self): super().__init__(Api.vector_io) async def register_vector_db(self, vector_db: VectorDB): return vector_db async def unregister_vector_db(self, vector_db_id: str): return vector_db_id async def openai_create_vector_store(self, **kwargs): import time import uuid from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}" return VectorStoreObject( id=vector_store_id, name=kwargs.get("name", vector_store_id), created_at=int(time.time()), file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), ) async def test_models_routing_table(cached_disk_dist_registry): table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple models and verify listing await table.register_model(model_id="test-model", provider_id="test_provider") await table.register_model(model_id="test-model-2", provider_id="test_provider") models = await table.list_models() assert len(models.data) == 2 model_ids = {m.identifier for m in models.data} assert "test_provider/test-model" in model_ids assert "test_provider/test-model-2" in model_ids # Test openai list models openai_models = await table.openai_list_models() assert len(openai_models.data) == 2 openai_model_ids = {m.id for m in openai_models.data} assert "test_provider/test-model" in openai_model_ids assert "test_provider/test-model-2" in openai_model_ids # Test get_object_by_identifier model = await table.get_object_by_identifier("model", "test_provider/test-model") assert model is not None assert model.identifier == "test_provider/test-model" # Test get_object_by_identifier on non-existent object non_existent = await table.get_object_by_identifier("model", "non-existent-model") assert non_existent is None await table.unregister_model(model_id="test_provider/test-model") await table.unregister_model(model_id="test_provider/test-model-2") models = await table.list_models() assert len(models.data) == 0 # Test openai list models openai_models = await table.openai_list_models() assert len(openai_models.data) == 0 async def test_shields_routing_table(cached_disk_dist_registry): table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple shields and verify listing await table.register_shield(shield_id="test-shield", provider_id="test_provider") await table.register_shield(shield_id="test-shield-2", provider_id="test_provider") shields = await table.list_shields() assert len(shields.data) == 2 shield_ids = {s.identifier for s in shields.data} assert "test-shield" in shield_ids assert "test-shield-2" in shield_ids # Test get specific shield test_shield = await table.get_shield(identifier="test-shield") assert test_shield is not None assert test_shield.identifier == "test-shield" assert test_shield.provider_id == "test_provider" assert test_shield.provider_resource_id == "test-shield" assert test_shield.params == {} # Test get non-existent shield - should raise ValueError with specific message with pytest.raises(ValueError, match="Shield 'non-existent' not found"): await table.get_shield(identifier="non-existent") # Test unregistering shields await table.unregister_shield(identifier="test-shield") shields = await table.list_shields() assert len(shields.data) == 1 shield_ids = {s.identifier for s in shields.data} assert "test-shield" not in shield_ids assert "test-shield-2" in shield_ids # Unregister the remaining shield await table.unregister_shield(identifier="test-shield-2") shields = await table.list_shields() assert len(shields.data) == 0 # Test unregistering non-existent shield - should raise ValueError with specific message with pytest.raises(ValueError, match="Shield 'non-existent' not found"): await table.unregister_shield(identifier="non-existent") async def test_vectordbs_routing_table(cached_disk_dist_registry): table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) await table.initialize() m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await m_table.initialize() await m_table.register_model( model_id="test-model", provider_id="test_provider", metadata={"embedding_dimension": 128}, model_type=ModelType.embedding, ) # Register multiple vector databases and verify listing vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 2 vector_db_ids = {v.identifier for v in vector_dbs.data} assert vdb1.identifier in vector_db_ids assert vdb2.identifier in vector_db_ids # Verify they have UUID-based identifiers assert vdb1.identifier.startswith("vs_") assert vdb2.identifier.startswith("vs_") await table.unregister_vector_db(vector_db_id=vdb1.identifier) await table.unregister_vector_db(vector_db_id=vdb2.identifier) vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 0 async def test_datasets_routing_table(cached_disk_dist_registry): table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple datasets and verify listing await table.register_dataset( dataset_id="test-dataset", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource(uri="test-uri") ) await table.register_dataset( dataset_id="test-dataset-2", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource(uri="test-uri-2") ) datasets = await table.list_datasets() assert len(datasets.data) == 2 dataset_ids = {d.identifier for d in datasets.data} assert "test-dataset" in dataset_ids assert "test-dataset-2" in dataset_ids await table.unregister_dataset(dataset_id="test-dataset") await table.unregister_dataset(dataset_id="test-dataset-2") datasets = await table.list_datasets() assert len(datasets.data) == 0 async def test_scoring_functions_routing_table(cached_disk_dist_registry): table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple scoring functions and verify listing await table.register_scoring_function( scoring_fn_id="test-scoring-fn", provider_id="test_provider", description="Test scoring function", return_type=NumberType(), ) await table.register_scoring_function( scoring_fn_id="test-scoring-fn-2", provider_id="test_provider", description="Another test scoring function", return_type=NumberType(), ) scoring_functions = await table.list_scoring_functions() assert len(scoring_functions.data) == 2 scoring_fn_ids = {fn.identifier for fn in scoring_functions.data} assert "test-scoring-fn" in scoring_fn_ids assert "test-scoring-fn-2" in scoring_fn_ids # Unregister scoring functions and verify listing for i in range(len(scoring_functions.data)): await table.unregister_scoring_function(scoring_functions.data[i].scoring_fn_id) scoring_functions_list_after_deletion = await table.list_scoring_functions() assert len(scoring_functions_list_after_deletion.data) == 0 async def test_benchmarks_routing_table(cached_disk_dist_registry): table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple benchmarks and verify listing await table.register_benchmark( benchmark_id="test-benchmark", dataset_id="test-dataset", scoring_functions=["test-scoring-fn", "test-scoring-fn-2"], ) benchmarks = await table.list_benchmarks() assert len(benchmarks.data) == 1 benchmark_ids = {b.identifier for b in benchmarks.data} assert "test-benchmark" in benchmark_ids # Unregister the benchmark and verify removal await table.unregister_benchmark(benchmark_id="test-benchmark") benchmarks_after = await table.list_benchmarks() assert len(benchmarks_after.data) == 0 # Unregistering a non-existent benchmark should raise a clear error with pytest.raises(ValueError, match="Benchmark 'dummy_benchmark' not found"): await table.unregister_benchmark(benchmark_id="dummy_benchmark") async def test_tool_groups_routing_table(cached_disk_dist_registry): table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register multiple tool groups and verify listing await table.register_tool_group( toolgroup_id="test-toolgroup", provider_id="test_provider", ) tool_groups = await table.list_tool_groups() assert len(tool_groups.data) == 1 tool_group_ids = {tg.identifier for tg in tool_groups.data} assert "test-toolgroup" in tool_group_ids await table.unregister_toolgroup(toolgroup_id="test-toolgroup") tool_groups = await table.list_tool_groups() assert len(tool_groups.data) == 0 async def test_models_alias_registration_and_lookup(cached_disk_dist_registry): """Test alias registration (model_id != provider_model_id) and lookup behavior.""" table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register model with alias (model_id different from provider_model_id) await table.register_model( model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider" ) # Verify the model was registered with alias as identifier (not namespaced) models = await table.list_models() assert len(models.data) == 1 model = models.data[0] assert model.identifier == "my-alias" # Uses alias as identifier assert model.provider_resource_id == "actual-provider-model" # Test lookup by alias works retrieved_model = await table.get_model("my-alias") assert retrieved_model.identifier == "my-alias" assert retrieved_model.provider_resource_id == "actual-provider-model" async def test_models_multi_provider_disambiguation(cached_disk_dist_registry): """Test registration and lookup with multiple providers having same provider_model_id.""" table = ModelsRoutingTable( {"provider1": InferenceImpl(), "provider2": InferenceImpl()}, cached_disk_dist_registry, {} ) await table.initialize() # Register same provider_model_id on both providers (no aliases) await table.register_model(model_id="common-model", provider_id="provider1") await table.register_model(model_id="common-model", provider_id="provider2") # Verify both models get namespaced identifiers models = await table.list_models() assert len(models.data) == 2 identifiers = {m.identifier for m in models.data} assert identifiers == {"provider1/common-model", "provider2/common-model"} # Test lookup by full namespaced identifier works model1 = await table.get_model("provider1/common-model") assert model1.provider_id == "provider1" assert model1.provider_resource_id == "common-model" model2 = await table.get_model("provider2/common-model") assert model2.provider_id == "provider2" assert model2.provider_resource_id == "common-model" # Test lookup by unscoped provider_model_id fails with multiple providers error try: await table.get_model("common-model") raise AssertionError("Should have raised ValueError for multiple providers") except ValueError as e: assert "Multiple providers found" in str(e) assert "provider1" in str(e) and "provider2" in str(e) async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): """Test two-stage lookup: direct identifier hit vs fallback to provider_resource_id.""" table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register model without alias (gets namespaced identifier) await table.register_model(model_id="test-model", provider_id="test_provider") # Verify namespaced identifier was created models = await table.list_models() assert len(models.data) == 1 model = models.data[0] assert model.identifier == "test_provider/test-model" assert model.provider_resource_id == "test-model" # Test lookup by full namespaced identifier (direct hit via get_object_by_identifier) retrieved_model = await table.get_model("test_provider/test-model") assert retrieved_model.identifier == "test_provider/test-model" # Test lookup by unscoped provider_model_id (fallback via iteration) retrieved_model = await table.get_model("test-model") assert retrieved_model.identifier == "test_provider/test-model" assert retrieved_model.provider_resource_id == "test-model" # Test lookup of non-existent model fails try: await table.get_model("non-existent") raise AssertionError("Should have raised ValueError for non-existent model") except ValueError as e: assert "not found" in str(e) async def test_models_source_tracking_default(cached_disk_dist_registry): """Test that models registered via register_model get default source.""" table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Register model via register_model (should get default source) await table.register_model(model_id="user-model", provider_id="test_provider") models = await table.list_models() assert len(models.data) == 1 model = models.data[0] assert model.source == RegistryEntrySource.via_register_api assert model.identifier == "test_provider/user-model" # Cleanup await table.shutdown() async def test_models_source_tracking_provider(cached_disk_dist_registry): """Test that models registered via update_registered_models get provider source.""" table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await table.initialize() # Simulate provider refresh by calling update_registered_models provider_models = [ Model( identifier="provider-model-1", provider_resource_id="provider-model-1", provider_id="test_provider", metadata={}, model_type=ModelType.llm, ), Model( identifier="provider-model-2", provider_resource_id="provider-model-2", provider_id="test_provider", metadata={"embedding_dimension": 512}, model_type=ModelType.embedding, ), ] await table.update_registered_models("test_provider", provider_models) models = await table.list_models() assert len(models.data) == 2 # All models should have provider source for model in models.data: assert model.source == RegistryEntrySource.listed_from_provider assert model.provider_id == "test_provider" # Cleanup await table.shutdown() async def test_models_dynamic_from_config_generation(cached_disk_dist_registry): """Test that from_config models are generated dynamically from run_config.""" table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) await table.initialize() # Test that no from_config models are registered when no run_config all_models = await table.get_all_with_type("model") from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config] assert len(from_config_models) == 0 # Create a run config with from_config models run_config = StackRunConfig( image_name="test", providers={}, models=[ { "model_id": "from_config_model_1", "provider_id": "test_provider", "model_type": "llm", "provider_model_id": "gpt-3.5-turbo", }, { "model_id": "from_config_model_2", "provider_id": "test_provider", "model_type": "llm", "provider_model_id": "gpt-4", }, ], ) # Set the run config table.current_run_config = run_config await table.cleanup_disabled_provider_models() await table.register_from_config_models() # Test that from_config models are registered in the registry all_models = await table.get_all_with_type("model") from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config] assert len(from_config_models) == 2 model_identifiers = {m.identifier for m in from_config_models} assert "from_config_model_1" in model_identifiers assert "from_config_model_2" in model_identifiers # Test that from_config models have correct attributes model_1 = next(m for m in from_config_models if m.identifier == "from_config_model_1") assert model_1.provider_id == "test_provider" assert model_1.provider_resource_id == "gpt-3.5-turbo" assert model_1.model_type == ModelType.llm assert model_1.source == RegistryEntrySource.from_config # Cleanup await table.shutdown() async def test_models_dynamic_from_config_lookup(cached_disk_dist_registry): """Test that from_config models can be looked up individually.""" table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) await table.initialize() # Create a run config with from_config models run_config = StackRunConfig( image_name="test", providers={}, models=[ { "model_id": "lookup_test_model", "provider_id": "test_provider", "model_type": "llm", "provider_model_id": "gpt-3.5-turbo", } ], ) # Set the run config table.current_run_config = run_config await table.cleanup_disabled_provider_models() await table.register_from_config_models() # Test that we can get the from_config model individually model = await table.get_model("lookup_test_model") assert model is not None assert model.identifier == "lookup_test_model" assert model.provider_id == "test_provider" assert model.provider_resource_id == "gpt-3.5-turbo" assert model.source == RegistryEntrySource.from_config # Cleanup await table.shutdown() async def test_models_dynamic_from_config_mixed_with_persistent(cached_disk_dist_registry): """Test that from_config models work alongside persistent models.""" table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) await table.initialize() # Create a run config with from_config models run_config = StackRunConfig( image_name="test", providers={}, models=[ { "model_id": "from_config_model", "provider_id": "test_provider", "model_type": "llm", "provider_model_id": "gpt-3.5-turbo", } ], ) # Set the run config table.current_run_config = run_config await table.cleanup_disabled_provider_models() await table.register_from_config_models() # Test that from_config models are included models = await table.list_models() from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config] assert len(from_config_models) == 1 assert from_config_models[0].identifier == "from_config_model" # Test that we can get the from_config model individually from_config_model = await table.get_model("from_config_model") assert from_config_model is not None assert from_config_model.source == RegistryEntrySource.from_config # Cleanup await table.shutdown() async def test_models_dynamic_from_config_disabled_providers(cached_disk_dist_registry): """Test that from_config models with disabled providers are skipped.""" table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) await table.initialize() # Create a run config with disabled provider models run_config = StackRunConfig( image_name="test", providers={}, models=[ { "model_id": "enabled_model", "provider_id": "test_provider", "model_type": "llm", "provider_model_id": "gpt-3.5-turbo", }, { "model_id": "disabled_model", "provider_id": "__disabled__", "model_type": "llm", "provider_model_id": "gpt-4", }, ], ) # Set the run config table.current_run_config = run_config await table.cleanup_disabled_provider_models() await table.register_from_config_models() # Test that only enabled models are included all_models = await table.get_all_with_type("model") from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config] assert len(from_config_models) == 1 assert from_config_models[0].identifier == "enabled_model" # Cleanup await table.shutdown() async def test_models_dynamic_from_config_no_run_config(cached_disk_dist_registry): """Test that from_config models work when no run_config is set.""" table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) await table.initialize() # Test that list_models works without run_config models = await table.list_models() from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config] assert len(from_config_models) == 0 # No from_config models when no run_config # Cleanup await table.shutdown() async def test_models_filter_persistent_models_from_removed_providers(cached_disk_dist_registry): """Test that models from removed providers are filtered out from persistent models.""" from llama_stack.apis.models import ModelType from llama_stack.core.datatypes import ModelWithOwner, Provider, RegistryEntrySource, StackRunConfig from llama_stack.core.routing_tables.models import ModelsRoutingTable # Create a routing table table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) await table.initialize() # Create some mock persistent models model1 = ModelWithOwner( identifier="test_provider_1/model1", provider_resource_id="model1", provider_id="test_provider_1", metadata={}, model_type=ModelType.llm, source=RegistryEntrySource.listed_from_provider, ) model2 = ModelWithOwner( identifier="test_provider_2/model2", provider_resource_id="model2", provider_id="test_provider_2", metadata={}, model_type=ModelType.llm, source=RegistryEntrySource.listed_from_provider, ) user_model = ModelWithOwner( identifier="user_model", provider_resource_id="user_model", provider_id="test_provider_1", metadata={}, model_type=ModelType.llm, source=RegistryEntrySource.via_register_api, ) # Create a run config that only includes test_provider_1 (test_provider_2 is removed) run_config = StackRunConfig( image_name="test", providers={ "inference": [ Provider(provider_id="test_provider_1", provider_type="openai", config={"api_key": "test_key"}), # test_provider_2 is removed from run.yaml ] }, models=[], ) # Set the run config table.current_run_config = run_config await table.cleanup_disabled_provider_models() await table.register_from_config_models() # Test the cleanup logic directly # First, manually add models to the registry to simulate existing models await table.dist_registry.register(model1) await table.dist_registry.register(model2) await table.dist_registry.register(user_model) # Now set the run config which should trigger cleanup table.current_run_config = run_config await table.cleanup_disabled_provider_models() await table.register_from_config_models() # Get the list of models after cleanup response = await table.list_models() model_identifiers = {m.identifier for m in response.data} # Should have user_model (user-registered) and model1 (from enabled provider), but not model2 (from disabled provider) # model1 should be kept because test_provider_1 is in the run config (enabled) # model2 should be removed because test_provider_2 is not in the run config (disabled) # user_model should be kept because it's user-registered assert "user_model" in model_identifiers assert "test_provider_1/model1" in model_identifiers assert "test_provider_2/model2" not in model_identifiers # Test that user-registered models are always kept regardless of provider status user_model_found = next((m for m in response.data if m.identifier == "user_model"), None) assert user_model_found is not None assert user_model_found.source == RegistryEntrySource.via_register_api # Cleanup await table.shutdown()