mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix(models)!: always prefix models with provider_id when registering (#3822)
**!!BREAKING CHANGE!!** The lookup is also straightforward -- we always look for this identifier and don't try to find a match for something without the provider_id prefix. Note that, this ideally means we need to update the `register_model()` API also (we should kill "identifier" from there) but I am not doing that as part of this PR. ## Test Plan Existing unit tests
This commit is contained in:
parent
f205ab6f6c
commit
f70aa99c97
10 changed files with 53 additions and 124 deletions
|
|
@ -170,7 +170,7 @@ jobs:
|
||||||
uv run --no-sync \
|
uv run --no-sync \
|
||||||
pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
|
pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
|
||||||
tests/integration/vector_io \
|
tests/integration/vector_io \
|
||||||
--embedding-model nomic-ai/nomic-embed-text-v1.5 \
|
--embedding-model inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5 \
|
||||||
--embedding-dimension 768
|
--embedding-dimension 768
|
||||||
|
|
||||||
- name: Check Storage and Memory Available After Tests
|
- name: Check Storage and Memory Available After Tests
|
||||||
|
|
|
||||||
|
|
@ -245,25 +245,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
|
|
||||||
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
|
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
|
||||||
# first try to get the model by identifier
|
|
||||||
# this works if model_id is an alias or is of the form provider_id/provider_model_id
|
|
||||||
model = await routing_table.get_object_by_identifier("model", model_id)
|
model = await routing_table.get_object_by_identifier("model", model_id)
|
||||||
if model is not None:
|
if not model:
|
||||||
return model
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
|
|
||||||
"searching in all providers. This is only for backwards compatibility and will stop working "
|
|
||||||
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
|
|
||||||
)
|
|
||||||
# if not found, this means model_id is an unscoped provider_model_id, we need
|
|
||||||
# to iterate (given a lack of an efficient index on the KVStore)
|
|
||||||
models = await routing_table.get_all_with_type("model")
|
|
||||||
matching_models = [m for m in models if m.provider_resource_id == model_id]
|
|
||||||
if len(matching_models) == 0:
|
|
||||||
raise ModelNotFoundError(model_id)
|
raise ModelNotFoundError(model_id)
|
||||||
|
return model
|
||||||
if len(matching_models) > 1:
|
|
||||||
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
|
|
||||||
|
|
||||||
return matching_models[0]
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
try:
|
try:
|
||||||
models = await provider.list_models()
|
models = await provider.list_models()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Model refresh failed for provider {provider_id}: {e}")
|
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.listed_providers.add(provider_id)
|
self.listed_providers.add(provider_id)
|
||||||
|
|
@ -104,15 +104,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||||
|
|
||||||
# an identifier different than provider_model_id implies it is an alias, so that
|
identifier = f"{provider_id}/{provider_model_id}"
|
||||||
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
|
|
||||||
# so as a general rule we must use the provider_id to disambiguate.
|
|
||||||
|
|
||||||
if model_id != provider_model_id:
|
|
||||||
identifier = model_id
|
|
||||||
else:
|
|
||||||
identifier = f"{provider_id}/{provider_model_id}"
|
|
||||||
|
|
||||||
model = ModelWithOwner(
|
model = ModelWithOwner(
|
||||||
identifier=identifier,
|
identifier=identifier,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
|
|
|
||||||
|
|
@ -435,7 +435,8 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
||||||
"""
|
"""
|
||||||
# First check if the model is pre-registered in the model store
|
# First check if the model is pre-registered in the model store
|
||||||
if hasattr(self, "model_store") and self.model_store:
|
if hasattr(self, "model_store") and self.model_store:
|
||||||
if await self.model_store.has_model(model):
|
qualified_model = f"{self.__provider_id__}/{model}" # type: ignore[attr-defined]
|
||||||
|
if await self.model_store.has_model(qualified_model):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Then check the provider's dynamic model cache
|
# Then check the provider's dynamic model cache
|
||||||
|
|
|
||||||
|
|
@ -290,7 +290,7 @@ pytest -s -v $PYTEST_TARGET \
|
||||||
-k "$PYTEST_PATTERN" \
|
-k "$PYTEST_PATTERN" \
|
||||||
$EXTRA_PARAMS \
|
$EXTRA_PARAMS \
|
||||||
--color=yes \
|
--color=yes \
|
||||||
--embedding-model=nomic-ai/nomic-embed-text-v1.5 \
|
--embedding-model=sentence-transformers/nomic-ai/nomic-embed-text-v1.5 \
|
||||||
--color=yes $EXTRA_PARAMS \
|
--color=yes $EXTRA_PARAMS \
|
||||||
--capture=tee-sys
|
--capture=tee-sys
|
||||||
exit_code=$?
|
exit_code=$?
|
||||||
|
|
|
||||||
|
|
@ -12,26 +12,7 @@
|
||||||
"body": {
|
"body": {
|
||||||
"__type__": "ollama._types.ProcessResponse",
|
"__type__": "ollama._types.ProcessResponse",
|
||||||
"__data__": {
|
"__data__": {
|
||||||
"models": [
|
"models": []
|
||||||
{
|
|
||||||
"model": "llama-guard3:1b",
|
|
||||||
"name": "llama-guard3:1b",
|
|
||||||
"digest": "494147e06bf99e10dbe67b63a07ac81c162f18ef3341aa3390007ac828571b3b",
|
|
||||||
"expires_at": "2025-10-13T14:07:12.309717-07:00",
|
|
||||||
"size": 2279663616,
|
|
||||||
"size_vram": 2279663616,
|
|
||||||
"details": {
|
|
||||||
"parent_model": "",
|
|
||||||
"format": "gguf",
|
|
||||||
"family": "llama",
|
|
||||||
"families": [
|
|
||||||
"llama"
|
|
||||||
],
|
|
||||||
"parameter_size": "1.5B",
|
|
||||||
"quantization_level": "Q8_0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"is_streaming": false
|
"is_streaming": false
|
||||||
|
|
|
||||||
|
|
@ -117,42 +117,24 @@ def client_with_models(
|
||||||
text_model_id,
|
text_model_id,
|
||||||
vision_model_id,
|
vision_model_id,
|
||||||
embedding_model_id,
|
embedding_model_id,
|
||||||
embedding_dimension,
|
|
||||||
judge_model_id,
|
judge_model_id,
|
||||||
):
|
):
|
||||||
client = llama_stack_client
|
client = llama_stack_client
|
||||||
|
|
||||||
providers = [p for p in client.providers.list() if p.api == "inference"]
|
providers = [p for p in client.providers.list() if p.api == "inference"]
|
||||||
assert len(providers) > 0, "No inference providers found"
|
assert len(providers) > 0, "No inference providers found"
|
||||||
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
|
|
||||||
|
|
||||||
model_ids = {m.identifier for m in client.models.list()}
|
model_ids = {m.identifier for m in client.models.list()}
|
||||||
model_ids.update(m.provider_resource_id for m in client.models.list())
|
|
||||||
|
|
||||||
# TODO: fix this crap where we use the first provider randomly
|
|
||||||
# that cannot be right. I think the test should just specify the provider_id
|
|
||||||
if text_model_id and text_model_id not in model_ids:
|
if text_model_id and text_model_id not in model_ids:
|
||||||
client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
|
raise ValueError(f"text_model_id {text_model_id} not found")
|
||||||
if vision_model_id and vision_model_id not in model_ids:
|
if vision_model_id and vision_model_id not in model_ids:
|
||||||
client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
|
raise ValueError(f"vision_model_id {vision_model_id} not found")
|
||||||
if judge_model_id and judge_model_id not in model_ids:
|
if judge_model_id and judge_model_id not in model_ids:
|
||||||
client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])
|
raise ValueError(f"judge_model_id {judge_model_id} not found")
|
||||||
|
|
||||||
if embedding_model_id and embedding_model_id not in model_ids:
|
if embedding_model_id and embedding_model_id not in model_ids:
|
||||||
# try to find a provider that supports embeddings, if sentence-transformers is not available
|
raise ValueError(f"embedding_model_id {embedding_model_id} not found")
|
||||||
selected_provider = None
|
|
||||||
for p in providers:
|
|
||||||
if p.provider_type == "inline::sentence-transformers":
|
|
||||||
selected_provider = p
|
|
||||||
break
|
|
||||||
|
|
||||||
selected_provider = selected_provider or providers[0]
|
|
||||||
client.models.register(
|
|
||||||
model_id=embedding_model_id,
|
|
||||||
provider_id=selected_provider.provider_id,
|
|
||||||
model_type="embedding",
|
|
||||||
metadata={"embedding_dimension": embedding_dimension or 768},
|
|
||||||
)
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ from unittest.mock import AsyncMock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
|
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
|
|
@ -450,6 +451,7 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register model with alias (model_id different from provider_model_id)
|
# Register model with alias (model_id different from provider_model_id)
|
||||||
|
# NOTE: Aliases are not supported anymore, so this is a no-op
|
||||||
await table.register_model(
|
await table.register_model(
|
||||||
model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
|
model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
|
||||||
)
|
)
|
||||||
|
|
@ -458,12 +460,15 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
|
||||||
models = await table.list_models()
|
models = await table.list_models()
|
||||||
assert len(models.data) == 1
|
assert len(models.data) == 1
|
||||||
model = models.data[0]
|
model = models.data[0]
|
||||||
assert model.identifier == "my-alias" # Uses alias as identifier
|
assert model.identifier == "test_provider/actual-provider-model"
|
||||||
assert model.provider_resource_id == "actual-provider-model"
|
assert model.provider_resource_id == "actual-provider-model"
|
||||||
|
|
||||||
# Test lookup by alias works
|
# Test lookup by alias fails
|
||||||
retrieved_model = await table.get_model("my-alias")
|
with pytest.raises(ModelNotFoundError, match="Model 'my-alias' not found"):
|
||||||
assert retrieved_model.identifier == "my-alias"
|
await table.get_model("my-alias")
|
||||||
|
|
||||||
|
retrieved_model = await table.get_model("test_provider/actual-provider-model")
|
||||||
|
assert retrieved_model.identifier == "test_provider/actual-provider-model"
|
||||||
assert retrieved_model.provider_resource_id == "actual-provider-model"
|
assert retrieved_model.provider_resource_id == "actual-provider-model"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -494,12 +499,8 @@ async def test_models_multi_provider_disambiguation(cached_disk_dist_registry):
|
||||||
assert model2.provider_resource_id == "common-model"
|
assert model2.provider_resource_id == "common-model"
|
||||||
|
|
||||||
# Test lookup by unscoped provider_model_id fails with multiple providers error
|
# Test lookup by unscoped provider_model_id fails with multiple providers error
|
||||||
try:
|
with pytest.raises(ModelNotFoundError, match="Model 'common-model' not found"):
|
||||||
await table.get_model("common-model")
|
await table.get_model("common-model")
|
||||||
raise AssertionError("Should have raised ValueError for multiple providers")
|
|
||||||
except ValueError as e:
|
|
||||||
assert "Multiple providers found" in str(e)
|
|
||||||
assert "provider1" in str(e) and "provider2" in str(e)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
||||||
|
|
@ -522,16 +523,12 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
||||||
assert retrieved_model.identifier == "test_provider/test-model"
|
assert retrieved_model.identifier == "test_provider/test-model"
|
||||||
|
|
||||||
# Test lookup by unscoped provider_model_id (fallback via iteration)
|
# Test lookup by unscoped provider_model_id (fallback via iteration)
|
||||||
retrieved_model = await table.get_model("test-model")
|
with pytest.raises(ModelNotFoundError, match="Model 'test-model' not found"):
|
||||||
assert retrieved_model.identifier == "test_provider/test-model"
|
await table.get_model("test-model")
|
||||||
assert retrieved_model.provider_resource_id == "test-model"
|
|
||||||
|
|
||||||
# Test lookup of non-existent model fails
|
# Test lookup of non-existent model fails
|
||||||
try:
|
with pytest.raises(ModelNotFoundError, match="Model 'non-existent' not found"):
|
||||||
await table.get_model("non-existent")
|
await table.get_model("non-existent")
|
||||||
raise AssertionError("Should have raised ValueError for non-existent model")
|
|
||||||
except ValueError as e:
|
|
||||||
assert "not found" in str(e)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_models_source_tracking_default(cached_disk_dist_registry):
|
async def test_models_source_tracking_default(cached_disk_dist_registry):
|
||||||
|
|
@ -603,7 +600,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
|
||||||
assert len(models.data) == 1
|
assert len(models.data) == 1
|
||||||
user_model = models.data[0]
|
user_model = models.data[0]
|
||||||
assert user_model.source == RegistryEntrySource.via_register_api
|
assert user_model.source == RegistryEntrySource.via_register_api
|
||||||
assert user_model.identifier == "my-custom-alias"
|
assert user_model.identifier == "test_provider/provider-model-1"
|
||||||
assert user_model.provider_resource_id == "provider-model-1"
|
assert user_model.provider_resource_id == "provider-model-1"
|
||||||
|
|
||||||
# Now simulate provider refresh
|
# Now simulate provider refresh
|
||||||
|
|
@ -630,7 +627,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
|
||||||
assert len(models.data) == 2
|
assert len(models.data) == 2
|
||||||
|
|
||||||
# Find the user model and provider model
|
# Find the user model and provider model
|
||||||
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
user_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-1"), None)
|
||||||
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
|
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
|
||||||
|
|
||||||
assert user_model is not None
|
assert user_model is not None
|
||||||
|
|
|
||||||
|
|
@ -205,7 +205,7 @@ class TestOpenAIMixinCheckModelAvailability:
|
||||||
assert await mixin.check_model_availability("pre-registered-model")
|
assert await mixin.check_model_availability("pre-registered-model")
|
||||||
# Should not call the provider's list_models since model was found in store
|
# Should not call the provider's list_models since model was found in store
|
||||||
mock_client_with_models.models.list.assert_not_called()
|
mock_client_with_models.models.list.assert_not_called()
|
||||||
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
|
mock_model_store.has_model.assert_called_once_with("test-provider/pre-registered-model")
|
||||||
|
|
||||||
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
|
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
|
||||||
self, mixin, mock_client_with_models, mock_client_context
|
self, mixin, mock_client_with_models, mock_client_context
|
||||||
|
|
@ -222,7 +222,7 @@ class TestOpenAIMixinCheckModelAvailability:
|
||||||
assert await mixin.check_model_availability("some-mock-model-id")
|
assert await mixin.check_model_availability("some-mock-model-id")
|
||||||
# Should call the provider's list_models since model was not found in store
|
# Should call the provider's list_models since model was not found in store
|
||||||
mock_client_with_models.models.list.assert_called_once()
|
mock_client_with_models.models.list.assert_called_once()
|
||||||
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
|
mock_model_store.has_model.assert_called_once_with("test-provider/some-mock-model-id")
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIMixinCacheBehavior:
|
class TestOpenAIMixinCacheBehavior:
|
||||||
|
|
|
||||||
|
|
@ -256,12 +256,12 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
|
||||||
- permit:
|
- permit:
|
||||||
principal: user-2
|
principal: user-2
|
||||||
actions: [read]
|
actions: [read]
|
||||||
resource: model::model-1
|
resource: model::test_provider/model-1
|
||||||
description: user-2 has read access to model-1 only
|
description: user-2 has read access to model-1 only
|
||||||
- permit:
|
- permit:
|
||||||
principal: user-3
|
principal: user-3
|
||||||
actions: [read]
|
actions: [read]
|
||||||
resource: model::model-2
|
resource: model::test_provider/model-2
|
||||||
description: user-3 has read access to model-2 only
|
description: user-3 has read access to model-2 only
|
||||||
- forbid:
|
- forbid:
|
||||||
actions: [create, read, delete]
|
actions: [create, read, delete]
|
||||||
|
|
@ -285,21 +285,15 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
||||||
"projects": ["foo", "bar"],
|
"projects": ["foo", "bar"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await routing_table.register_model(
|
await routing_table.register_model("model-1", provider_model_id="model-1", provider_id="test_provider")
|
||||||
"model-1", provider_model_id="test_provider/model-1", provider_id="test_provider"
|
await routing_table.register_model("model-2", provider_model_id="model-2", provider_id="test_provider")
|
||||||
)
|
await routing_table.register_model("model-3", provider_model_id="model-3", provider_id="test_provider")
|
||||||
await routing_table.register_model(
|
model = await routing_table.get_model("test_provider/model-1")
|
||||||
"model-2", provider_model_id="test_provider/model-2", provider_id="test_provider"
|
assert model.identifier == "test_provider/model-1"
|
||||||
)
|
model = await routing_table.get_model("test_provider/model-2")
|
||||||
await routing_table.register_model(
|
assert model.identifier == "test_provider/model-2"
|
||||||
"model-3", provider_model_id="test_provider/model-3", provider_id="test_provider"
|
model = await routing_table.get_model("test_provider/model-3")
|
||||||
)
|
assert model.identifier == "test_provider/model-3"
|
||||||
model = await routing_table.get_model("model-1")
|
|
||||||
assert model.identifier == "model-1"
|
|
||||||
model = await routing_table.get_model("model-2")
|
|
||||||
assert model.identifier == "model-2"
|
|
||||||
model = await routing_table.get_model("model-3")
|
|
||||||
assert model.identifier == "model-3"
|
|
||||||
|
|
||||||
mock_get_authenticated_user.return_value = User(
|
mock_get_authenticated_user.return_value = User(
|
||||||
"user-2",
|
"user-2",
|
||||||
|
|
@ -308,16 +302,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
||||||
"projects": ["foo"],
|
"projects": ["foo"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model = await routing_table.get_model("model-1")
|
model = await routing_table.get_model("test_provider/model-1")
|
||||||
assert model.identifier == "model-1"
|
assert model.identifier == "test_provider/model-1"
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await routing_table.get_model("model-2")
|
await routing_table.get_model("test_provider/model-2")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await routing_table.get_model("model-3")
|
await routing_table.get_model("test_provider/model-3")
|
||||||
with pytest.raises(AccessDeniedError):
|
with pytest.raises(AccessDeniedError):
|
||||||
await routing_table.register_model("model-4", provider_id="test_provider")
|
await routing_table.register_model("model-4", provider_id="test_provider")
|
||||||
with pytest.raises(AccessDeniedError):
|
with pytest.raises(AccessDeniedError):
|
||||||
await routing_table.unregister_model("model-1")
|
await routing_table.unregister_model("test_provider/model-1")
|
||||||
|
|
||||||
mock_get_authenticated_user.return_value = User(
|
mock_get_authenticated_user.return_value = User(
|
||||||
"user-3",
|
"user-3",
|
||||||
|
|
@ -326,16 +320,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
||||||
"projects": ["bar"],
|
"projects": ["bar"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
model = await routing_table.get_model("model-2")
|
model = await routing_table.get_model("test_provider/model-2")
|
||||||
assert model.identifier == "model-2"
|
assert model.identifier == "test_provider/model-2"
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await routing_table.get_model("model-1")
|
await routing_table.get_model("test_provider/model-1")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await routing_table.get_model("model-3")
|
await routing_table.get_model("test_provider/model-3")
|
||||||
with pytest.raises(AccessDeniedError):
|
with pytest.raises(AccessDeniedError):
|
||||||
await routing_table.register_model("model-5", provider_id="test_provider")
|
await routing_table.register_model("model-5", provider_id="test_provider")
|
||||||
with pytest.raises(AccessDeniedError):
|
with pytest.raises(AccessDeniedError):
|
||||||
await routing_table.unregister_model("model-2")
|
await routing_table.unregister_model("test_provider/model-2")
|
||||||
|
|
||||||
mock_get_authenticated_user.return_value = User(
|
mock_get_authenticated_user.return_value = User(
|
||||||
"user-1",
|
"user-1",
|
||||||
|
|
@ -344,9 +338,9 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
|
||||||
"projects": ["foo", "bar"],
|
"projects": ["foo", "bar"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await routing_table.unregister_model("model-3")
|
await routing_table.unregister_model("test_provider/model-3")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await routing_table.get_model("model-3")
|
await routing_table.get_model("test_provider/model-3")
|
||||||
|
|
||||||
|
|
||||||
def test_permit_when():
|
def test_permit_when():
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue