Patch summary (descriptive preamble; patch tools skip text that precedes the
first "diff --git" line).

  * llama_stack/distribution/routing_tables/models.py
      The "Please specify a provider_id ..." ValueError message gains a second
      sentence telling the caller to disambiguate with a 'provider_id/model_id'
      prefix.

  * tests/unit/distribution/routers/test_routing_tables.py
      - Imports ModelType, VectorDB and VectorDBsRoutingTable.
      - Adds VectorDBImpl, a stub whose register/unregister methods return
        their argument unchanged.
      - test_models_routing_table now expects identifiers prefixed with
        "test_provider/".
      - Adds test_vectordbs_routing_table, which registers an embedding model
        and two vector DBs, lists them, then unregisters them.
      - Adds three model tests: alias registration and lookup, multi-provider
        disambiguation, and fallback lookup by unscoped id.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. The
leading indentation inside each hunk was reconstructed from Python syntax and
the hunk line counts -- confirm against the target tree before applying.

diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py
index dc3dcf5b2..d6fa7ab6b 100644
--- a/llama_stack/distribution/routing_tables/models.py
+++ b/llama_stack/distribution/routing_tables/models.py
@@ -68,7 +68,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                 provider_id = list(self.impls_by_provider_id.keys())[0]
             else:
                 raise ValueError(
-                    f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}"
+                    f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
+                    "Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
                 )
 
         provider_model_id = provider_model_id or model_id
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index 30f795d33..12b05ebff 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -11,15 +11,17 @@ from unittest.mock import AsyncMock
 from llama_stack.apis.common.type_system import NumberType
 from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
 from llama_stack.apis.datatypes import Api
-from llama_stack.apis.models import Model
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.shields.shields import Shield
 from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
+from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
 from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
 from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
 from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
 from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
 from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
+from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
 
 
 class Impl:
@@ -104,6 +106,17 @@ class ToolGroupsImpl(Impl):
         )
 
 
+class VectorDBImpl(Impl):
+    def __init__(self):
+        super().__init__(Api.vector_io)
+
+    async def register_vector_db(self, vector_db: VectorDB):
+        return vector_db
+
+    async def unregister_vector_db(self, vector_db_id: str):
+        return vector_db_id
+
+
 async def test_models_routing_table(cached_disk_dist_registry):
     table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
     await table.initialize()
@@ -115,27 +128,27 @@ async def test_models_routing_table(cached_disk_dist_registry):
     models = await table.list_models()
     assert len(models.data) == 2
     model_ids = {m.identifier for m in models.data}
-    assert "test-model" in model_ids
-    assert "test-model-2" in model_ids
+    assert "test_provider/test-model" in model_ids
+    assert "test_provider/test-model-2" in model_ids
 
     # Test openai list models
     openai_models = await table.openai_list_models()
     assert len(openai_models.data) == 2
     openai_model_ids = {m.id for m in openai_models.data}
-    assert "test-model" in openai_model_ids
-    assert "test-model-2" in openai_model_ids
+    assert "test_provider/test-model" in openai_model_ids
+    assert "test_provider/test-model-2" in openai_model_ids
 
     # Test get_object_by_identifier
-    model = await table.get_object_by_identifier("model", "test-model")
+    model = await table.get_object_by_identifier("model", "test_provider/test-model")
     assert model is not None
-    assert model.identifier == "test-model"
+    assert model.identifier == "test_provider/test-model"
 
     # Test get_object_by_identifier on non-existent object
     non_existent = await table.get_object_by_identifier("model", "non-existent-model")
     assert non_existent is None
 
-    await table.unregister_model(model_id="test-model")
-    await table.unregister_model(model_id="test-model-2")
+    await table.unregister_model(model_id="test_provider/test-model")
+    await table.unregister_model(model_id="test_provider/test-model-2")
 
     models = await table.list_models()
     assert len(models.data) == 0
@@ -160,6 +173,36 @@ async def test_shields_routing_table(cached_disk_dist_registry):
     assert "test-shield-2" in shield_ids
 
 
+async def test_vectordbs_routing_table(cached_disk_dist_registry):
+    table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await m_table.initialize()
+    await m_table.register_model(
+        model_id="test-model",
+        provider_id="test_provider",
+        metadata={"embedding_dimension": 128},
+        model_type=ModelType.embedding,
+    )
+
+    # Register multiple vector databases and verify listing
+    await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
+    await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
+    vector_dbs = await table.list_vector_dbs()
+
+    assert len(vector_dbs.data) == 2
+    vector_db_ids = {v.identifier for v in vector_dbs.data}
+    assert "test-vectordb" in vector_db_ids
+    assert "test-vectordb-2" in vector_db_ids
+
+    await table.unregister_vector_db(vector_db_id="test-vectordb")
+    await table.unregister_vector_db(vector_db_id="test-vectordb-2")
+
+    vector_dbs = await table.list_vector_dbs()
+    assert len(vector_dbs.data) == 0
+
+
 async def test_datasets_routing_table(cached_disk_dist_registry):
     table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
     await table.initialize()
@@ -245,3 +288,93 @@ async def test_tool_groups_routing_table(cached_disk_dist_registry):
     await table.unregister_toolgroup(toolgroup_id="test-toolgroup")
     tool_groups = await table.list_tool_groups()
     assert len(tool_groups.data) == 0
+
+
+async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
+    """Test alias registration (model_id != provider_model_id) and lookup behavior."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register model with alias (model_id different from provider_model_id)
+    await table.register_model(
+        model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
+    )
+
+    # Verify the model was registered with alias as identifier (not namespaced)
+    models = await table.list_models()
+    assert len(models.data) == 1
+    model = models.data[0]
+    assert model.identifier == "my-alias"  # Uses alias as identifier
+    assert model.provider_resource_id == "actual-provider-model"
+
+    # Test lookup by alias works
+    retrieved_model = await table.get_model("my-alias")
+    assert retrieved_model.identifier == "my-alias"
+    assert retrieved_model.provider_resource_id == "actual-provider-model"
+
+
+async def test_models_multi_provider_disambiguation(cached_disk_dist_registry):
+    """Test registration and lookup with multiple providers having same provider_model_id."""
+    table = ModelsRoutingTable(
+        {"provider1": InferenceImpl(), "provider2": InferenceImpl()}, cached_disk_dist_registry, {}
+    )
+    await table.initialize()
+
+    # Register same provider_model_id on both providers (no aliases)
+    await table.register_model(model_id="common-model", provider_id="provider1")
+    await table.register_model(model_id="common-model", provider_id="provider2")
+
+    # Verify both models get namespaced identifiers
+    models = await table.list_models()
+    assert len(models.data) == 2
+    identifiers = {m.identifier for m in models.data}
+    assert identifiers == {"provider1/common-model", "provider2/common-model"}
+
+    # Test lookup by full namespaced identifier works
+    model1 = await table.get_model("provider1/common-model")
+    assert model1.provider_id == "provider1"
+    assert model1.provider_resource_id == "common-model"
+
+    model2 = await table.get_model("provider2/common-model")
+    assert model2.provider_id == "provider2"
+    assert model2.provider_resource_id == "common-model"
+
+    # Test lookup by unscoped provider_model_id fails with multiple providers error
+    try:
+        await table.get_model("common-model")
+        raise AssertionError("Should have raised ValueError for multiple providers")
+    except ValueError as e:
+        assert "Multiple providers found" in str(e)
+        assert "provider1" in str(e) and "provider2" in str(e)
+
+
+async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
+    """Test two-stage lookup: direct identifier hit vs fallback to provider_resource_id."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register model without alias (gets namespaced identifier)
+    await table.register_model(model_id="test-model", provider_id="test_provider")
+
+    # Verify namespaced identifier was created
+    models = await table.list_models()
+    assert len(models.data) == 1
+    model = models.data[0]
+    assert model.identifier == "test_provider/test-model"
+    assert model.provider_resource_id == "test-model"
+
+    # Test lookup by full namespaced identifier (direct hit via get_object_by_identifier)
+    retrieved_model = await table.get_model("test_provider/test-model")
+    assert retrieved_model.identifier == "test_provider/test-model"
+
+    # Test lookup by unscoped provider_model_id (fallback via iteration)
+    retrieved_model = await table.get_model("test-model")
+    assert retrieved_model.identifier == "test_provider/test-model"
+    assert retrieved_model.provider_resource_id == "test-model"
+
+    # Test lookup of non-existent model fails
+    try:
+        await table.get_model("non-existent")
+        raise AssertionError("Should have raised ValueError for non-existent model")
+    except ValueError as e:
+        assert "not found" in str(e)