fix(models)!: always prefix models with provider_id when registering (#3822)

**!!BREAKING CHANGE!!**

The lookup is also straightforward -- we always look for this identifier
and don't try to find a match for something without the provider_id
prefix.

Note that, this ideally means we need to update the `register_model()`
API also (we should kill "identifier" from there) but I am not doing
that as part of this PR.

## Test Plan

Existing unit tests
This commit is contained in:
Ashwin Bharambe 2025-10-16 06:47:39 -07:00 committed by GitHub
parent f205ab6f6c
commit f70aa99c97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 53 additions and 124 deletions

View file

@ -12,26 +12,7 @@
"body": {
"__type__": "ollama._types.ProcessResponse",
"__data__": {
"models": [
{
"model": "llama-guard3:1b",
"name": "llama-guard3:1b",
"digest": "494147e06bf99e10dbe67b63a07ac81c162f18ef3341aa3390007ac828571b3b",
"expires_at": "2025-10-13T14:07:12.309717-07:00",
"size": 2279663616,
"size_vram": 2279663616,
"details": {
"parent_model": "",
"format": "gguf",
"family": "llama",
"families": [
"llama"
],
"parameter_size": "1.5B",
"quantization_level": "Q8_0"
}
}
]
"models": []
}
},
"is_streaming": false

View file

@ -117,42 +117,24 @@ def client_with_models(
text_model_id,
vision_model_id,
embedding_model_id,
embedding_dimension,
judge_model_id,
):
client = llama_stack_client
providers = [p for p in client.providers.list() if p.api == "inference"]
assert len(providers) > 0, "No inference providers found"
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
model_ids = {m.identifier for m in client.models.list()}
model_ids.update(m.provider_resource_id for m in client.models.list())
# TODO: fix this crap where we use the first provider randomly
# that cannot be right. I think the test should just specify the provider_id
if text_model_id and text_model_id not in model_ids:
client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
raise ValueError(f"text_model_id {text_model_id} not found")
if vision_model_id and vision_model_id not in model_ids:
client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
raise ValueError(f"vision_model_id {vision_model_id} not found")
if judge_model_id and judge_model_id not in model_ids:
client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])
raise ValueError(f"judge_model_id {judge_model_id} not found")
if embedding_model_id and embedding_model_id not in model_ids:
# try to find a provider that supports embeddings, if sentence-transformers is not available
selected_provider = None
for p in providers:
if p.provider_type == "inline::sentence-transformers":
selected_provider = p
break
selected_provider = selected_provider or providers[0]
client.models.register(
model_id=embedding_model_id,
provider_id=selected_provider.provider_id,
model_type="embedding",
metadata={"embedding_dimension": embedding_dimension or 768},
)
raise ValueError(f"embedding_model_id {embedding_model_id} not found")
return client

View file

@ -11,6 +11,7 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
@ -450,6 +451,7 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
await table.initialize()
# Register model with alias (model_id different from provider_model_id)
# NOTE: Aliases are not supported anymore, so this is a no-op
await table.register_model(
model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
)
@ -458,12 +460,15 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
models = await table.list_models()
assert len(models.data) == 1
model = models.data[0]
assert model.identifier == "my-alias" # Uses alias as identifier
assert model.identifier == "test_provider/actual-provider-model"
assert model.provider_resource_id == "actual-provider-model"
# Test lookup by alias works
retrieved_model = await table.get_model("my-alias")
assert retrieved_model.identifier == "my-alias"
# Test lookup by alias fails
with pytest.raises(ModelNotFoundError, match="Model 'my-alias' not found"):
await table.get_model("my-alias")
retrieved_model = await table.get_model("test_provider/actual-provider-model")
assert retrieved_model.identifier == "test_provider/actual-provider-model"
assert retrieved_model.provider_resource_id == "actual-provider-model"
@ -494,12 +499,8 @@ async def test_models_multi_provider_disambiguation(cached_disk_dist_registry):
assert model2.provider_resource_id == "common-model"
# Test lookup by unscoped provider_model_id fails with multiple providers error
try:
with pytest.raises(ModelNotFoundError, match="Model 'common-model' not found"):
await table.get_model("common-model")
raise AssertionError("Should have raised ValueError for multiple providers")
except ValueError as e:
assert "Multiple providers found" in str(e)
assert "provider1" in str(e) and "provider2" in str(e)
async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
@ -522,16 +523,12 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
assert retrieved_model.identifier == "test_provider/test-model"
# Test lookup by unscoped provider_model_id (fallback via iteration)
retrieved_model = await table.get_model("test-model")
assert retrieved_model.identifier == "test_provider/test-model"
assert retrieved_model.provider_resource_id == "test-model"
with pytest.raises(ModelNotFoundError, match="Model 'test-model' not found"):
await table.get_model("test-model")
# Test lookup of non-existent model fails
try:
with pytest.raises(ModelNotFoundError, match="Model 'non-existent' not found"):
await table.get_model("non-existent")
raise AssertionError("Should have raised ValueError for non-existent model")
except ValueError as e:
assert "not found" in str(e)
async def test_models_source_tracking_default(cached_disk_dist_registry):
@ -603,7 +600,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
assert len(models.data) == 1
user_model = models.data[0]
assert user_model.source == RegistryEntrySource.via_register_api
assert user_model.identifier == "my-custom-alias"
assert user_model.identifier == "test_provider/provider-model-1"
assert user_model.provider_resource_id == "provider-model-1"
# Now simulate provider refresh
@ -630,7 +627,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
assert len(models.data) == 2
# Find the user model and provider model
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
user_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-1"), None)
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
assert user_model is not None

View file

@ -205,7 +205,7 @@ class TestOpenAIMixinCheckModelAvailability:
assert await mixin.check_model_availability("pre-registered-model")
# Should not call the provider's list_models since model was found in store
mock_client_with_models.models.list.assert_not_called()
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
mock_model_store.has_model.assert_called_once_with("test-provider/pre-registered-model")
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
self, mixin, mock_client_with_models, mock_client_context
@ -222,7 +222,7 @@ class TestOpenAIMixinCheckModelAvailability:
assert await mixin.check_model_availability("some-mock-model-id")
# Should call the provider's list_models since model was not found in store
mock_client_with_models.models.list.assert_called_once()
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
mock_model_store.has_model.assert_called_once_with("test-provider/some-mock-model-id")
class TestOpenAIMixinCacheBehavior:

View file

@ -256,12 +256,12 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
- permit:
principal: user-2
actions: [read]
resource: model::model-1
resource: model::test_provider/model-1
description: user-2 has read access to model-1 only
- permit:
principal: user-3
actions: [read]
resource: model::model-2
resource: model::test_provider/model-2
description: user-3 has read access to model-2 only
- forbid:
actions: [create, read, delete]
@ -285,21 +285,15 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo", "bar"],
},
)
await routing_table.register_model(
"model-1", provider_model_id="test_provider/model-1", provider_id="test_provider"
)
await routing_table.register_model(
"model-2", provider_model_id="test_provider/model-2", provider_id="test_provider"
)
await routing_table.register_model(
"model-3", provider_model_id="test_provider/model-3", provider_id="test_provider"
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("model-3")
assert model.identifier == "model-3"
await routing_table.register_model("model-1", provider_model_id="model-1", provider_id="test_provider")
await routing_table.register_model("model-2", provider_model_id="model-2", provider_id="test_provider")
await routing_table.register_model("model-3", provider_model_id="model-3", provider_id="test_provider")
model = await routing_table.get_model("test_provider/model-1")
assert model.identifier == "test_provider/model-1"
model = await routing_table.get_model("test_provider/model-2")
assert model.identifier == "test_provider/model-2"
model = await routing_table.get_model("test_provider/model-3")
assert model.identifier == "test_provider/model-3"
mock_get_authenticated_user.return_value = User(
"user-2",
@ -308,16 +302,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo"],
},
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("test_provider/model-1")
assert model.identifier == "test_provider/model-1"
with pytest.raises(ValueError):
await routing_table.get_model("model-2")
await routing_table.get_model("test_provider/model-2")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
await routing_table.get_model("test_provider/model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-4", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-1")
await routing_table.unregister_model("test_provider/model-1")
mock_get_authenticated_user.return_value = User(
"user-3",
@ -326,16 +320,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["bar"],
},
)
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("test_provider/model-2")
assert model.identifier == "test_provider/model-2"
with pytest.raises(ValueError):
await routing_table.get_model("model-1")
await routing_table.get_model("test_provider/model-1")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
await routing_table.get_model("test_provider/model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-5", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-2")
await routing_table.unregister_model("test_provider/model-2")
mock_get_authenticated_user.return_value = User(
"user-1",
@ -344,9 +338,9 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo", "bar"],
},
)
await routing_table.unregister_model("model-3")
await routing_table.unregister_model("test_provider/model-3")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
await routing_table.get_model("test_provider/model-3")
def test_permit_when():