access control unit test

This commit is contained in:
Ashwin Bharambe 2025-10-15 22:00:47 -07:00
parent ea0d342c5d
commit 54354257dc
2 changed files with 25 additions and 30 deletions

View file

@ -435,7 +435,8 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
""" """
# First check if the model is pre-registered in the model store # First check if the model is pre-registered in the model store
if hasattr(self, "model_store") and self.model_store: if hasattr(self, "model_store") and self.model_store:
if await self.model_store.has_model(model): qualified_model = f"{self.__provider_id__}/{model}"
if await self.model_store.has_model(qualified_model):
return True return True
# Then check the provider's dynamic model cache # Then check the provider's dynamic model cache

View file

@ -256,12 +256,12 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
- permit: - permit:
principal: user-2 principal: user-2
actions: [read] actions: [read]
resource: model::model-1 resource: model::test_provider/model-1
description: user-2 has read access to model-1 only description: user-2 has read access to model-1 only
- permit: - permit:
principal: user-3 principal: user-3
actions: [read] actions: [read]
resource: model::model-2 resource: model::test_provider/model-2
description: user-3 has read access to model-2 only description: user-3 has read access to model-2 only
- forbid: - forbid:
actions: [create, read, delete] actions: [create, read, delete]
@ -285,21 +285,15 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo", "bar"], "projects": ["foo", "bar"],
}, },
) )
await routing_table.register_model( await routing_table.register_model("model-1", provider_model_id="model-1", provider_id="test_provider")
"model-1", provider_model_id="test_provider/model-1", provider_id="test_provider" await routing_table.register_model("model-2", provider_model_id="model-2", provider_id="test_provider")
) await routing_table.register_model("model-3", provider_model_id="model-3", provider_id="test_provider")
await routing_table.register_model( model = await routing_table.get_model("test_provider/model-1")
"model-2", provider_model_id="test_provider/model-2", provider_id="test_provider" assert model.identifier == "test_provider/model-1"
) model = await routing_table.get_model("test_provider/model-2")
await routing_table.register_model( assert model.identifier == "test_provider/model-2"
"model-3", provider_model_id="test_provider/model-3", provider_id="test_provider" model = await routing_table.get_model("test_provider/model-3")
) assert model.identifier == "test_provider/model-3"
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("model-3")
assert model.identifier == "model-3"
mock_get_authenticated_user.return_value = User( mock_get_authenticated_user.return_value = User(
"user-2", "user-2",
@ -308,16 +302,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo"], "projects": ["foo"],
}, },
) )
model = await routing_table.get_model("model-1") model = await routing_table.get_model("test_provider/model-1")
assert model.identifier == "model-1" assert model.identifier == "test_provider/model-1"
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-2") await routing_table.get_model("test_provider/model-2")
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-3") await routing_table.get_model("test_provider/model-3")
with pytest.raises(AccessDeniedError): with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-4", provider_id="test_provider") await routing_table.register_model("model-4", provider_id="test_provider")
with pytest.raises(AccessDeniedError): with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-1") await routing_table.unregister_model("test_provider/model-1")
mock_get_authenticated_user.return_value = User( mock_get_authenticated_user.return_value = User(
"user-3", "user-3",
@ -326,16 +320,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["bar"], "projects": ["bar"],
}, },
) )
model = await routing_table.get_model("model-2") model = await routing_table.get_model("test_provider/model-2")
assert model.identifier == "model-2" assert model.identifier == "test_provider/model-2"
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-1") await routing_table.get_model("test_provider/model-1")
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-3") await routing_table.get_model("test_provider/model-3")
with pytest.raises(AccessDeniedError): with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-5", provider_id="test_provider") await routing_table.register_model("model-5", provider_id="test_provider")
with pytest.raises(AccessDeniedError): with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-2") await routing_table.unregister_model("test_provider/model-2")
mock_get_authenticated_user.return_value = User( mock_get_authenticated_user.return_value = User(
"user-1", "user-1",
@ -344,9 +338,9 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo", "bar"], "projects": ["foo", "bar"],
}, },
) )
await routing_table.unregister_model("model-3") await routing_table.unregister_model("test_provider/model-3")
with pytest.raises(ValueError): with pytest.raises(ValueError):
await routing_table.get_model("model-3") await routing_table.get_model("test_provider/model-3")
def test_permit_when(): def test_permit_when():