mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 23:08:47 +00:00
flatten alias to provider map, test registering to existing alias
This commit is contained in:
parent
9982aa64f0
commit
84351c2d67
2 changed files with 79 additions and 49 deletions
|
|
@ -11,11 +11,14 @@
|
|||
#
|
||||
# Test cases -
|
||||
# - Looking up an alias that does not exist should return None.
|
||||
# - Registering a model + provider ID should add the model to the registry. If provider ID is known.
|
||||
# - Registering an existing model should return an error.
|
||||
# - Registering a model + provider ID should add the model to the registry. If
|
||||
# provider ID is known or an alias for a provider ID.
|
||||
# - Registering an existing model should return an error. Unless it's a
|
||||
# dulicate entry.
|
||||
# - Unregistering a model should remove it from the registry.
|
||||
# - Unregistering a model that does not exist should return an error.
|
||||
# - Models can be registered during initialization or via register_model.
|
||||
# - Supported model ID and their aliases are registered during initialization.
|
||||
# Only aliases are added afterwards.
|
||||
#
|
||||
# Questions -
|
||||
# - Should we be allowed to register models w/o provider model IDs? No.
|
||||
|
|
@ -45,10 +48,27 @@ def known_model() -> Model:
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def known_model2() -> Model:
|
||||
return Model(
|
||||
provider_id="provider",
|
||||
identifier="known-model2",
|
||||
provider_resource_id="known-provider-id2",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def known_provider_model(known_model: Model) -> ProviderModelEntry:
|
||||
return ProviderModelEntry(
|
||||
provider_model_id=known_model.provider_resource_id,
|
||||
aliases=[known_model.model_id],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def known_provider_model2(known_model2: Model) -> ProviderModelEntry:
|
||||
return ProviderModelEntry(
|
||||
provider_model_id=known_model2.provider_resource_id,
|
||||
# aliases=[],
|
||||
)
|
||||
|
||||
|
|
@ -63,8 +83,8 @@ def unknown_model() -> Model:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def helper(known_provider_model: ProviderModelEntry) -> ModelRegistryHelper:
|
||||
return ModelRegistryHelper([known_provider_model])
|
||||
def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper:
|
||||
return ModelRegistryHelper([known_provider_model, known_provider_model2])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -80,21 +100,46 @@ async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unkn
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
assert helper.get_provider_model_id(known_model.model_id) is None
|
||||
await helper.register_model(known_model)
|
||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
|
||||
model = Model(
|
||||
provider_id=known_model.provider_id,
|
||||
identifier="new-model",
|
||||
provider_resource_id=known_model.provider_resource_id,
|
||||
)
|
||||
assert helper.get_provider_model_id(model.model_id) is None
|
||||
await helper.register_model(model)
|
||||
assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
model = Model(
|
||||
provider_id=known_model.provider_id,
|
||||
identifier="new-model",
|
||||
provider_resource_id=known_model.model_id, # use known model's id as an alias for the supported model id
|
||||
)
|
||||
assert helper.get_provider_model_id(model.model_id) is None
|
||||
await helper.register_model(model)
|
||||
assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
await helper.register_model(known_model)
|
||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_existing_different(
|
||||
helper: ModelRegistryHelper, known_model: Model, known_model2: Model
|
||||
) -> None:
|
||||
known_model.provider_resource_id = known_model2.provider_resource_id
|
||||
with pytest.raises(ValueError):
|
||||
await helper.register_model(known_model)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
await helper.register_model(known_model)
|
||||
await helper.register_model(known_model) # duplicate entry
|
||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
|
||||
await helper.unregister_model(known_model.model_id)
|
||||
assert helper.get_provider_model_id(known_model.model_id) is None
|
||||
|
|
@ -111,14 +156,6 @@ async def test_register_model_during_init(helper: ModelRegistryHelper, known_mod
|
|||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_existing_from_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||
with pytest.raises(ValueError):
|
||||
known_model.identifier = known_model.provider_resource_id
|
||||
await helper.register_model(known_model)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue