From f70aa99c97381ceef0f171534167fd9ca3a35b0d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 16 Oct 2025 06:47:39 -0700 Subject: [PATCH] fix(models)!: always prefix models with provider_id when registering (#3822) **!!BREAKING CHANGE!!** The lookup is also straightforward -- we always look for this identifier and don't try to find a match for something without the provider_id prefix. Note that this ideally means we also need to update the `register_model()` API (we should remove "identifier" from there), but I am not doing that as part of this PR. ## Test Plan Existing unit tests --- .../workflows/integration-vector-io-tests.yml | 2 +- llama_stack/core/routing_tables/common.py | 22 +------- llama_stack/core/routing_tables/models.py | 12 +---- .../providers/utils/inference/openai_mixin.py | 3 +- scripts/integration-tests.sh | 2 +- ...54792b9f22d2cb4522eab802810be8672d3dc.json | 21 +------- tests/integration/fixtures/common.py | 26 ++-------- .../routers/test_routing_tables.py | 33 ++++++------ .../utils/inference/test_openai_mixin.py | 4 +- tests/unit/server/test_access_control.py | 52 ++++++++----------- 10 files changed, 53 insertions(+), 124 deletions(-) diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index 9dd0e260b..fe5785c73 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -170,7 +170,7 @@ jobs: uv run --no-sync \ pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ tests/integration/vector_io \ - --embedding-model nomic-ai/nomic-embed-text-v1.5 \ + --embedding-model inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5 \ --embedding-dimension 768 - name: Check Storage and Memory Available After Tests diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index 0b5aa7843..8df0a89a9 
100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -245,25 +245,7 @@ class CommonRoutingTableImpl(RoutingTable): async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model: - # first try to get the model by identifier - # this works if model_id is an alias or is of the form provider_id/provider_model_id model = await routing_table.get_object_by_identifier("model", model_id) - if model is not None: - return model - - logger.warning( - f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to " - "searching in all providers. This is only for backwards compatibility and will stop working " - "soon. Migrate your calls to use fully scoped `provider_id/model_id` names." - ) - # if not found, this means model_id is an unscoped provider_model_id, we need - # to iterate (given a lack of an efficient index on the KVStore) - models = await routing_table.get_all_with_type("model") - matching_models = [m for m in models if m.provider_resource_id == model_id] - if len(matching_models) == 0: + if not model: raise ModelNotFoundError(model_id) - - if len(matching_models) > 1: - raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}") - - return matching_models[0] + return model diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 716be936a..7e43d7273 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): try: models = await provider.list_models() except Exception as e: - logger.debug(f"Model refresh failed for provider {provider_id}: {e}") + logger.warning(f"Model refresh failed for provider {provider_id}: {e}") continue self.listed_providers.add(provider_id) @@ -104,15 +104,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): if 
"embedding_dimension" not in metadata and model_type == ModelType.embedding: raise ValueError("Embedding model must have an embedding dimension in its metadata") - # an identifier different than provider_model_id implies it is an alias, so that - # becomes the globally unique identifier. otherwise provider_model_ids can conflict, - # so as a general rule we must use the provider_id to disambiguate. - - if model_id != provider_model_id: - identifier = model_id - else: - identifier = f"{provider_id}/{provider_model_id}" - + identifier = f"{provider_id}/{provider_model_id}" model = ModelWithOwner( identifier=identifier, provider_resource_id=provider_model_id, diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 11c0b6829..dc397aa76 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -435,7 +435,8 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): """ # First check if the model is pre-registered in the model store if hasattr(self, "model_store") and self.model_store: - if await self.model_store.has_model(model): + qualified_model = f"{self.__provider_id__}/{model}" # type: ignore[attr-defined] + if await self.model_store.has_model(qualified_model): return True # Then check the provider's dynamic model cache diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh index f3dc32745..138f1d144 100755 --- a/scripts/integration-tests.sh +++ b/scripts/integration-tests.sh @@ -290,7 +290,7 @@ pytest -s -v $PYTEST_TARGET \ -k "$PYTEST_PATTERN" \ $EXTRA_PARAMS \ --color=yes \ - --embedding-model=nomic-ai/nomic-embed-text-v1.5 \ + --embedding-model=sentence-transformers/nomic-ai/nomic-embed-text-v1.5 \ --color=yes $EXTRA_PARAMS \ --capture=tee-sys exit_code=$? 
diff --git a/tests/integration/common/recordings/02c93bb3c314427bae2b7a7a6f054792b9f22d2cb4522eab802810be8672d3dc.json b/tests/integration/common/recordings/02c93bb3c314427bae2b7a7a6f054792b9f22d2cb4522eab802810be8672d3dc.json index 4ea0ee13f..2b2afeee4 100644 --- a/tests/integration/common/recordings/02c93bb3c314427bae2b7a7a6f054792b9f22d2cb4522eab802810be8672d3dc.json +++ b/tests/integration/common/recordings/02c93bb3c314427bae2b7a7a6f054792b9f22d2cb4522eab802810be8672d3dc.json @@ -12,26 +12,7 @@ "body": { "__type__": "ollama._types.ProcessResponse", "__data__": { - "models": [ - { - "model": "llama-guard3:1b", - "name": "llama-guard3:1b", - "digest": "494147e06bf99e10dbe67b63a07ac81c162f18ef3341aa3390007ac828571b3b", - "expires_at": "2025-10-13T14:07:12.309717-07:00", - "size": 2279663616, - "size_vram": 2279663616, - "details": { - "parent_model": "", - "format": "gguf", - "family": "llama", - "families": [ - "llama" - ], - "parameter_size": "1.5B", - "quantization_level": "Q8_0" - } - } - ] + "models": [] } }, "is_streaming": false diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 6ebf0aed7..68a30fc69 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -117,42 +117,24 @@ def client_with_models( text_model_id, vision_model_id, embedding_model_id, - embedding_dimension, judge_model_id, ): client = llama_stack_client providers = [p for p in client.providers.list() if p.api == "inference"] assert len(providers) > 0, "No inference providers found" - inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"] model_ids = {m.identifier for m in client.models.list()} - model_ids.update(m.provider_resource_id for m in client.models.list()) - # TODO: fix this crap where we use the first provider randomly - # that cannot be right. 
I think the test should just specify the provider_id if text_model_id and text_model_id not in model_ids: - client.models.register(model_id=text_model_id, provider_id=inference_providers[0]) + raise ValueError(f"text_model_id {text_model_id} not found") if vision_model_id and vision_model_id not in model_ids: - client.models.register(model_id=vision_model_id, provider_id=inference_providers[0]) + raise ValueError(f"vision_model_id {vision_model_id} not found") if judge_model_id and judge_model_id not in model_ids: - client.models.register(model_id=judge_model_id, provider_id=inference_providers[0]) + raise ValueError(f"judge_model_id {judge_model_id} not found") if embedding_model_id and embedding_model_id not in model_ids: - # try to find a provider that supports embeddings, if sentence-transformers is not available - selected_provider = None - for p in providers: - if p.provider_type == "inline::sentence-transformers": - selected_provider = p - break - - selected_provider = selected_provider or providers[0] - client.models.register( - model_id=embedding_model_id, - provider_id=selected_provider.provider_id, - model_type="embedding", - metadata={"embedding_dimension": embedding_dimension or 768}, - ) + raise ValueError(f"embedding_model_id {embedding_model_id} not found") return client diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index beb0b4a95..87ebcef00 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -11,6 +11,7 @@ from unittest.mock import AsyncMock import pytest from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.errors import ModelNotFoundError from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api @@ -450,6 +451,7 @@ async def 
test_models_alias_registration_and_lookup(cached_disk_dist_registry): await table.initialize() # Register model with alias (model_id different from provider_model_id) + # NOTE: Aliases are not supported anymore, so this is a no-op await table.register_model( model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider" ) @@ -458,12 +460,15 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry): models = await table.list_models() assert len(models.data) == 1 model = models.data[0] - assert model.identifier == "my-alias" # Uses alias as identifier + assert model.identifier == "test_provider/actual-provider-model" assert model.provider_resource_id == "actual-provider-model" - # Test lookup by alias works - retrieved_model = await table.get_model("my-alias") - assert retrieved_model.identifier == "my-alias" + # Test lookup by alias fails + with pytest.raises(ModelNotFoundError, match="Model 'my-alias' not found"): + await table.get_model("my-alias") + + retrieved_model = await table.get_model("test_provider/actual-provider-model") + assert retrieved_model.identifier == "test_provider/actual-provider-model" assert retrieved_model.provider_resource_id == "actual-provider-model" @@ -494,12 +499,8 @@ async def test_models_multi_provider_disambiguation(cached_disk_dist_registry): assert model2.provider_resource_id == "common-model" # Test lookup by unscoped provider_model_id fails with multiple providers error - try: + with pytest.raises(ModelNotFoundError, match="Model 'common-model' not found"): await table.get_model("common-model") - raise AssertionError("Should have raised ValueError for multiple providers") - except ValueError as e: - assert "Multiple providers found" in str(e) - assert "provider1" in str(e) and "provider2" in str(e) async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): @@ -522,16 +523,12 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): assert 
retrieved_model.identifier == "test_provider/test-model" # Test lookup by unscoped provider_model_id (fallback via iteration) - retrieved_model = await table.get_model("test-model") - assert retrieved_model.identifier == "test_provider/test-model" - assert retrieved_model.provider_resource_id == "test-model" + with pytest.raises(ModelNotFoundError, match="Model 'test-model' not found"): + await table.get_model("test-model") # Test lookup of non-existent model fails - try: + with pytest.raises(ModelNotFoundError, match="Model 'non-existent' not found"): await table.get_model("non-existent") - raise AssertionError("Should have raised ValueError for non-existent model") - except ValueError as e: - assert "not found" in str(e) async def test_models_source_tracking_default(cached_disk_dist_registry): @@ -603,7 +600,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi assert len(models.data) == 1 user_model = models.data[0] assert user_model.source == RegistryEntrySource.via_register_api - assert user_model.identifier == "my-custom-alias" + assert user_model.identifier == "test_provider/provider-model-1" assert user_model.provider_resource_id == "provider-model-1" # Now simulate provider refresh @@ -630,7 +627,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi assert len(models.data) == 2 # Find the user model and provider model - user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) + user_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-1"), None) provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None) assert user_model is not None diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 80c219055..78241bc22 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ 
b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -205,7 +205,7 @@ class TestOpenAIMixinCheckModelAvailability: assert await mixin.check_model_availability("pre-registered-model") # Should not call the provider's list_models since model was found in store mock_client_with_models.models.list.assert_not_called() - mock_model_store.has_model.assert_called_once_with("pre-registered-model") + mock_model_store.has_model.assert_called_once_with("test-provider/pre-registered-model") async def test_check_model_availability_fallback_to_provider_when_not_in_store( self, mixin, mock_client_with_models, mock_client_context @@ -222,7 +222,7 @@ class TestOpenAIMixinCheckModelAvailability: assert await mixin.check_model_availability("some-mock-model-id") # Should call the provider's list_models since model was not found in store mock_client_with_models.models.list.assert_called_once() - mock_model_store.has_model.assert_called_once_with("some-mock-model-id") + mock_model_store.has_model.assert_called_once_with("test-provider/some-mock-model-id") class TestOpenAIMixinCacheBehavior: diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index 55449804a..ea4f9b8b2 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -256,12 +256,12 @@ async def test_setup_with_access_policy(cached_disk_dist_registry): - permit: principal: user-2 actions: [read] - resource: model::model-1 + resource: model::test_provider/model-1 description: user-2 has read access to model-1 only - permit: principal: user-3 actions: [read] - resource: model::model-2 + resource: model::test_provider/model-2 description: user-3 has read access to model-2 only - forbid: actions: [create, read, delete] @@ -285,21 +285,15 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access "projects": ["foo", "bar"], }, ) - await routing_table.register_model( - "model-1", 
provider_model_id="test_provider/model-1", provider_id="test_provider" - ) - await routing_table.register_model( - "model-2", provider_model_id="test_provider/model-2", provider_id="test_provider" - ) - await routing_table.register_model( - "model-3", provider_model_id="test_provider/model-3", provider_id="test_provider" - ) - model = await routing_table.get_model("model-1") - assert model.identifier == "model-1" - model = await routing_table.get_model("model-2") - assert model.identifier == "model-2" - model = await routing_table.get_model("model-3") - assert model.identifier == "model-3" + await routing_table.register_model("model-1", provider_model_id="model-1", provider_id="test_provider") + await routing_table.register_model("model-2", provider_model_id="model-2", provider_id="test_provider") + await routing_table.register_model("model-3", provider_model_id="model-3", provider_id="test_provider") + model = await routing_table.get_model("test_provider/model-1") + assert model.identifier == "test_provider/model-1" + model = await routing_table.get_model("test_provider/model-2") + assert model.identifier == "test_provider/model-2" + model = await routing_table.get_model("test_provider/model-3") + assert model.identifier == "test_provider/model-3" mock_get_authenticated_user.return_value = User( "user-2", @@ -308,16 +302,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access "projects": ["foo"], }, ) - model = await routing_table.get_model("model-1") - assert model.identifier == "model-1" + model = await routing_table.get_model("test_provider/model-1") + assert model.identifier == "test_provider/model-1" with pytest.raises(ValueError): - await routing_table.get_model("model-2") + await routing_table.get_model("test_provider/model-2") with pytest.raises(ValueError): - await routing_table.get_model("model-3") + await routing_table.get_model("test_provider/model-3") with pytest.raises(AccessDeniedError): await 
routing_table.register_model("model-4", provider_id="test_provider") with pytest.raises(AccessDeniedError): - await routing_table.unregister_model("model-1") + await routing_table.unregister_model("test_provider/model-1") mock_get_authenticated_user.return_value = User( "user-3", @@ -326,16 +320,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access "projects": ["bar"], }, ) - model = await routing_table.get_model("model-2") - assert model.identifier == "model-2" + model = await routing_table.get_model("test_provider/model-2") + assert model.identifier == "test_provider/model-2" with pytest.raises(ValueError): - await routing_table.get_model("model-1") + await routing_table.get_model("test_provider/model-1") with pytest.raises(ValueError): - await routing_table.get_model("model-3") + await routing_table.get_model("test_provider/model-3") with pytest.raises(AccessDeniedError): await routing_table.register_model("model-5", provider_id="test_provider") with pytest.raises(AccessDeniedError): - await routing_table.unregister_model("model-2") + await routing_table.unregister_model("test_provider/model-2") mock_get_authenticated_user.return_value = User( "user-1", @@ -344,9 +338,9 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access "projects": ["foo", "bar"], }, ) - await routing_table.unregister_model("model-3") + await routing_table.unregister_model("test_provider/model-3") with pytest.raises(ValueError): - await routing_table.get_model("model-3") + await routing_table.get_model("test_provider/model-3") def test_permit_when():