fix(models)!: always prefix models with provider_id when registering (#3822)

**!!BREAKING CHANGE!!**

The lookup is also straightforward -- we always look for this identifier
and don't try to find a match for something without the provider_id
prefix.

Note that, this ideally means we need to update the `register_model()`
API also (we should kill "identifier" from there) but I am not doing
that as part of this PR.

## Test Plan

Existing unit tests
This commit is contained in:
Ashwin Bharambe 2025-10-16 06:47:39 -07:00 committed by GitHub
parent f205ab6f6c
commit f70aa99c97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 53 additions and 124 deletions

View file

@ -256,12 +256,12 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
- permit:
principal: user-2
actions: [read]
resource: model::model-1
resource: model::test_provider/model-1
description: user-2 has read access to model-1 only
- permit:
principal: user-3
actions: [read]
resource: model::model-2
resource: model::test_provider/model-2
description: user-3 has read access to model-2 only
- forbid:
actions: [create, read, delete]
@ -285,21 +285,15 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo", "bar"],
},
)
await routing_table.register_model(
"model-1", provider_model_id="test_provider/model-1", provider_id="test_provider"
)
await routing_table.register_model(
"model-2", provider_model_id="test_provider/model-2", provider_id="test_provider"
)
await routing_table.register_model(
"model-3", provider_model_id="test_provider/model-3", provider_id="test_provider"
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("model-3")
assert model.identifier == "model-3"
await routing_table.register_model("model-1", provider_model_id="model-1", provider_id="test_provider")
await routing_table.register_model("model-2", provider_model_id="model-2", provider_id="test_provider")
await routing_table.register_model("model-3", provider_model_id="model-3", provider_id="test_provider")
model = await routing_table.get_model("test_provider/model-1")
assert model.identifier == "test_provider/model-1"
model = await routing_table.get_model("test_provider/model-2")
assert model.identifier == "test_provider/model-2"
model = await routing_table.get_model("test_provider/model-3")
assert model.identifier == "test_provider/model-3"
mock_get_authenticated_user.return_value = User(
"user-2",
@ -308,16 +302,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo"],
},
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("test_provider/model-1")
assert model.identifier == "test_provider/model-1"
with pytest.raises(ValueError):
await routing_table.get_model("model-2")
await routing_table.get_model("test_provider/model-2")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
await routing_table.get_model("test_provider/model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-4", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-1")
await routing_table.unregister_model("test_provider/model-1")
mock_get_authenticated_user.return_value = User(
"user-3",
@ -326,16 +320,16 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["bar"],
},
)
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("test_provider/model-2")
assert model.identifier == "test_provider/model-2"
with pytest.raises(ValueError):
await routing_table.get_model("model-1")
await routing_table.get_model("test_provider/model-1")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
await routing_table.get_model("test_provider/model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-5", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-2")
await routing_table.unregister_model("test_provider/model-2")
mock_get_authenticated_user.return_value = User(
"user-1",
@ -344,9 +338,9 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo", "bar"],
},
)
await routing_table.unregister_model("model-3")
await routing_table.unregister_model("test_provider/model-3")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
await routing_table.get_model("test_provider/model-3")
def test_permit_when():