mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-18 23:28:53 +00:00
fix(models)!: always prefix models with provider_id when registering (#3822)
**!!BREAKING CHANGE!!** The lookup is also straightforward -- we always look for this identifier and don't try to find a match for something without the provider_id prefix. Note that, this ideally means we need to update the `register_model()` API also (we should kill "identifier" from there) but I am not doing that as part of this PR. ## Test Plan Existing unit tests
This commit is contained in:
parent
f205ab6f6c
commit
f70aa99c97
10 changed files with 53 additions and 124 deletions
|
@ -11,6 +11,7 @@ from unittest.mock import AsyncMock
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.datatypes import Api
|
||||
|
@ -450,6 +451,7 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
|
|||
await table.initialize()
|
||||
|
||||
# Register model with alias (model_id different from provider_model_id)
|
||||
# NOTE: Aliases are not supported anymore, so this is a no-op
|
||||
await table.register_model(
|
||||
model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
|
||||
)
|
||||
|
@ -458,12 +460,15 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
|
|||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
model = models.data[0]
|
||||
assert model.identifier == "my-alias" # Uses alias as identifier
|
||||
assert model.identifier == "test_provider/actual-provider-model"
|
||||
assert model.provider_resource_id == "actual-provider-model"
|
||||
|
||||
# Test lookup by alias works
|
||||
retrieved_model = await table.get_model("my-alias")
|
||||
assert retrieved_model.identifier == "my-alias"
|
||||
# Test lookup by alias fails
|
||||
with pytest.raises(ModelNotFoundError, match="Model 'my-alias' not found"):
|
||||
await table.get_model("my-alias")
|
||||
|
||||
retrieved_model = await table.get_model("test_provider/actual-provider-model")
|
||||
assert retrieved_model.identifier == "test_provider/actual-provider-model"
|
||||
assert retrieved_model.provider_resource_id == "actual-provider-model"
|
||||
|
||||
|
||||
|
@ -494,12 +499,8 @@ async def test_models_multi_provider_disambiguation(cached_disk_dist_registry):
|
|||
assert model2.provider_resource_id == "common-model"
|
||||
|
||||
# Test lookup by unscoped provider_model_id fails with multiple providers error
|
||||
try:
|
||||
with pytest.raises(ModelNotFoundError, match="Model 'common-model' not found"):
|
||||
await table.get_model("common-model")
|
||||
raise AssertionError("Should have raised ValueError for multiple providers")
|
||||
except ValueError as e:
|
||||
assert "Multiple providers found" in str(e)
|
||||
assert "provider1" in str(e) and "provider2" in str(e)
|
||||
|
||||
|
||||
async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
||||
|
@ -522,16 +523,12 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
|||
assert retrieved_model.identifier == "test_provider/test-model"
|
||||
|
||||
# Test lookup by unscoped provider_model_id (fallback via iteration)
|
||||
retrieved_model = await table.get_model("test-model")
|
||||
assert retrieved_model.identifier == "test_provider/test-model"
|
||||
assert retrieved_model.provider_resource_id == "test-model"
|
||||
with pytest.raises(ModelNotFoundError, match="Model 'test-model' not found"):
|
||||
await table.get_model("test-model")
|
||||
|
||||
# Test lookup of non-existent model fails
|
||||
try:
|
||||
with pytest.raises(ModelNotFoundError, match="Model 'non-existent' not found"):
|
||||
await table.get_model("non-existent")
|
||||
raise AssertionError("Should have raised ValueError for non-existent model")
|
||||
except ValueError as e:
|
||||
assert "not found" in str(e)
|
||||
|
||||
|
||||
async def test_models_source_tracking_default(cached_disk_dist_registry):
|
||||
|
@ -603,7 +600,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
|
|||
assert len(models.data) == 1
|
||||
user_model = models.data[0]
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert user_model.identifier == "my-custom-alias"
|
||||
assert user_model.identifier == "test_provider/provider-model-1"
|
||||
assert user_model.provider_resource_id == "provider-model-1"
|
||||
|
||||
# Now simulate provider refresh
|
||||
|
@ -630,7 +627,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
|
|||
assert len(models.data) == 2
|
||||
|
||||
# Find the user model and provider model
|
||||
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
||||
user_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-1"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
|
||||
|
||||
assert user_model is not None
|
||||
|
|
|
@ -205,7 +205,7 @@ class TestOpenAIMixinCheckModelAvailability:
|
|||
assert await mixin.check_model_availability("pre-registered-model")
|
||||
# Should not call the provider's list_models since model was found in store
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
|
||||
mock_model_store.has_model.assert_called_once_with("test-provider/pre-registered-model")
|
||||
|
||||
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
|
||||
self, mixin, mock_client_with_models, mock_client_context
|
||||
|
@ -222,7 +222,7 @@ class TestOpenAIMixinCheckModelAvailability:
|
|||
assert await mixin.check_model_availability("some-mock-model-id")
|
||||
# Should call the provider's list_models since model was not found in store
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
|
||||
mock_model_store.has_model.assert_called_once_with("test-provider/some-mock-model-id")
|
||||
|
||||
|
||||
class TestOpenAIMixinCacheBehavior:
|
||||
|
|
|
@ -256,12 +256,12 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
|
|||
- permit:
|
||||
principal: user-2
|
||||
actions: [read]
|
||||
resource: model::model-1
|
||||
resource: model::test_provider/model-1
|
||||
description: user-2 has read access to model-1 only
|
||||
- permit:
|
||||
principal: user-3
|
||||
actions: [read]
|
||||
resource: model::model-2
|
||||
resource: model::test_provider/model-2
|
||||
description: user-3 has read access to model-2 only
|
||||
- forbid:
|
||||
actions: [create, read, delete]
|
||||
|
@ -285,21 +285,15 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
|||
"projects": ["foo", "bar"],
|
||||
},
|
||||
)
|
||||
await routing_table.register_model(
|
||||
"model-1", provider_model_id="test_provider/model-1", provider_id="test_provider"
|
||||
)
|
||||
await routing_table.register_model(
|
||||
"model-2", provider_model_id="test_provider/model-2", provider_id="test_provider"
|
||||
)
|
||||
await routing_table.register_model(
|
||||
"model-3", provider_model_id="test_provider/model-3", provider_id="test_provider"
|
||||
)
|
||||
model = await routing_table.get_model("model-1")
|
||||
assert model.identifier == "model-1"
|
||||
model = await routing_table.get_model("model-2")
|
||||
assert model.identifier == "model-2"
|
||||
model = await routing_table.get_model("model-3")
|
||||
assert model.identifier == "model-3"
|
||||
await routing_table.register_model("model-1", provider_model_id="model-1", provider_id="test_provider")
|
||||
await routing_table.register_model("model-2", provider_model_id="model-2", provider_id="test_provider")
|
||||
await routing_table.register_model("model-3", provider_model_id="model-3", provider_id="test_provider")
|
||||
model = await routing_table.get_model("test_provider/model-1")
|
||||
assert model.identifier == "test_provider/model-1"
|
||||
model = await routing_table.get_model("test_provider/model-2")
|
||||
assert model.identifier == "test_provider/model-2"
|
||||
model = await routing_table.get_model("test_provider/model-3")
|
||||
assert model.identifier == "test_provider/model-3"
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-2",
|
||||
|
@ -308,16 +302,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
|||
"projects": ["foo"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-1")
|
||||
assert model.identifier == "model-1"
|
||||
model = await routing_table.get_model("test_provider/model-1")
|
||||
assert model.identifier == "test_provider/model-1"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-2")
|
||||
await routing_table.get_model("test_provider/model-2")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
await routing_table.get_model("test_provider/model-3")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.register_model("model-4", provider_id="test_provider")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.unregister_model("model-1")
|
||||
await routing_table.unregister_model("test_provider/model-1")
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-3",
|
||||
|
@ -326,16 +320,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
|||
"projects": ["bar"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-2")
|
||||
assert model.identifier == "model-2"
|
||||
model = await routing_table.get_model("test_provider/model-2")
|
||||
assert model.identifier == "test_provider/model-2"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-1")
|
||||
await routing_table.get_model("test_provider/model-1")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
await routing_table.get_model("test_provider/model-3")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.register_model("model-5", provider_id="test_provider")
|
||||
with pytest.raises(AccessDeniedError):
|
||||
await routing_table.unregister_model("model-2")
|
||||
await routing_table.unregister_model("test_provider/model-2")
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-1",
|
||||
|
@ -344,9 +338,9 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
|||
"projects": ["foo", "bar"],
|
||||
},
|
||||
)
|
||||
await routing_table.unregister_model("model-3")
|
||||
await routing_table.unregister_model("test_provider/model-3")
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-3")
|
||||
await routing_table.get_model("test_provider/model-3")
|
||||
|
||||
|
||||
def test_permit_when():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue