mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
flatten alias to provider map, test registering to existing alias
This commit is contained in:
parent
9982aa64f0
commit
84351c2d67
2 changed files with 79 additions and 49 deletions
|
@ -59,8 +59,6 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
|
||||||
|
|
||||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
def __init__(self, model_entries: List[ProviderModelEntry]):
|
def __init__(self, model_entries: List[ProviderModelEntry]):
|
||||||
self.supported_model_ids = {entry.provider_model_id for entry in model_entries}
|
|
||||||
|
|
||||||
self.alias_to_provider_id_map = {}
|
self.alias_to_provider_id_map = {}
|
||||||
self.provider_id_to_llama_model_map = {}
|
self.provider_id_to_llama_model_map = {}
|
||||||
for entry in model_entries:
|
for entry in model_entries:
|
||||||
|
@ -77,54 +75,49 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
def get_provider_model_id(self, identifier: str) -> Optional[str]:
|
def get_provider_model_id(self, identifier: str) -> Optional[str]:
|
||||||
return self.alias_to_provider_id_map.get(identifier, None)
|
return self.alias_to_provider_id_map.get(identifier, None)
|
||||||
|
|
||||||
|
# TODO: why keep a separate llama model mapping?
|
||||||
def get_llama_model(self, provider_model_id: str) -> Optional[str]:
|
def get_llama_model(self, provider_model_id: str) -> Optional[str]:
|
||||||
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
if model.provider_resource_id not in self.supported_model_ids:
|
if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Model id '{model.provider_resource_id}' is not supported. Supported ids are: {', '.join(self.supported_model_ids)}"
|
f"Model '{model.provider_resource_id}' is not supported. Supported models are: {', '.join(self.alias_to_provider_id_map.keys())}"
|
||||||
)
|
)
|
||||||
if model.model_id in self.alias_to_provider_id_map:
|
provider_resource_id = self.get_provider_model_id(model.model_id)
|
||||||
# be idemopotent
|
|
||||||
if model.provider_resource_id != self.alias_to_provider_id_map[model.model_id]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
|
|
||||||
)
|
|
||||||
if model.model_type == ModelType.embedding:
|
if model.model_type == ModelType.embedding:
|
||||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||||
provider_resource_id = model.provider_resource_id
|
provider_resource_id = model.provider_resource_id
|
||||||
else:
|
|
||||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
|
||||||
|
|
||||||
if provider_resource_id:
|
if provider_resource_id:
|
||||||
model.provider_resource_id = provider_resource_id
|
if provider_resource_id != supported_model_id: # be idemopotent, only reject differences
|
||||||
|
raise ValueError(
|
||||||
|
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
llama_model = model.metadata.get("llama_model")
|
llama_model = model.metadata.get("llama_model")
|
||||||
if llama_model is None:
|
if llama_model:
|
||||||
return model
|
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
||||||
|
if existing_llama_model:
|
||||||
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
if existing_llama_model != llama_model:
|
||||||
if existing_llama_model:
|
raise ValueError(
|
||||||
if existing_llama_model != llama_model:
|
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||||
raise ValueError(
|
)
|
||||||
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
else:
|
||||||
|
if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid llama_model '{llama_model}' specified in metadata. "
|
||||||
|
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
|
||||||
|
)
|
||||||
|
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||||
|
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid llama_model '{llama_model}' specified in metadata. "
|
|
||||||
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
|
|
||||||
)
|
|
||||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
|
||||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.alias_to_provider_id_map[model.model_id] = model.provider_resource_id
|
self.alias_to_provider_id_map[model.model_id] = supported_model_id
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
|
# TODO: should we block unregistering base supported provider model IDs?
|
||||||
if model_id not in self.alias_to_provider_id_map:
|
if model_id not in self.alias_to_provider_id_map:
|
||||||
raise ValueError(f"Model id '{model_id}' is not registered.")
|
raise ValueError(f"Model id '{model_id}' is not registered.")
|
||||||
|
|
||||||
|
|
|
@ -11,11 +11,14 @@
|
||||||
#
|
#
|
||||||
# Test cases -
|
# Test cases -
|
||||||
# - Looking up an alias that does not exist should return None.
|
# - Looking up an alias that does not exist should return None.
|
||||||
# - Registering a model + provider ID should add the model to the registry. If provider ID is known.
|
# - Registering a model + provider ID should add the model to the registry. If
|
||||||
# - Registering an existing model should return an error.
|
# provider ID is known or an alias for a provider ID.
|
||||||
|
# - Registering an existing model should return an error. Unless it's a
|
||||||
|
# dulicate entry.
|
||||||
# - Unregistering a model should remove it from the registry.
|
# - Unregistering a model should remove it from the registry.
|
||||||
# - Unregistering a model that does not exist should return an error.
|
# - Unregistering a model that does not exist should return an error.
|
||||||
# - Models can be registered during initialization or via register_model.
|
# - Supported model ID and their aliases are registered during initialization.
|
||||||
|
# Only aliases are added afterwards.
|
||||||
#
|
#
|
||||||
# Questions -
|
# Questions -
|
||||||
# - Should we be allowed to register models w/o provider model IDs? No.
|
# - Should we be allowed to register models w/o provider model IDs? No.
|
||||||
|
@ -45,10 +48,27 @@ def known_model() -> Model:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def known_model2() -> Model:
|
||||||
|
return Model(
|
||||||
|
provider_id="provider",
|
||||||
|
identifier="known-model2",
|
||||||
|
provider_resource_id="known-provider-id2",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def known_provider_model(known_model: Model) -> ProviderModelEntry:
|
def known_provider_model(known_model: Model) -> ProviderModelEntry:
|
||||||
return ProviderModelEntry(
|
return ProviderModelEntry(
|
||||||
provider_model_id=known_model.provider_resource_id,
|
provider_model_id=known_model.provider_resource_id,
|
||||||
|
aliases=[known_model.model_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def known_provider_model2(known_model2: Model) -> ProviderModelEntry:
|
||||||
|
return ProviderModelEntry(
|
||||||
|
provider_model_id=known_model2.provider_resource_id,
|
||||||
# aliases=[],
|
# aliases=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -63,8 +83,8 @@ def unknown_model() -> Model:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def helper(known_provider_model: ProviderModelEntry) -> ModelRegistryHelper:
|
def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper:
|
||||||
return ModelRegistryHelper([known_provider_model])
|
return ModelRegistryHelper([known_provider_model, known_provider_model2])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -80,21 +100,46 @@ async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unkn
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
assert helper.get_provider_model_id(known_model.model_id) is None
|
model = Model(
|
||||||
await helper.register_model(known_model)
|
provider_id=known_model.provider_id,
|
||||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
|
identifier="new-model",
|
||||||
|
provider_resource_id=known_model.provider_resource_id,
|
||||||
|
)
|
||||||
|
assert helper.get_provider_model_id(model.model_id) is None
|
||||||
|
await helper.register_model(model)
|
||||||
|
assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
|
model = Model(
|
||||||
|
provider_id=known_model.provider_id,
|
||||||
|
identifier="new-model",
|
||||||
|
provider_resource_id=known_model.model_id, # use known model's id as an alias for the supported model id
|
||||||
|
)
|
||||||
|
assert helper.get_provider_model_id(model.model_id) is None
|
||||||
|
await helper.register_model(model)
|
||||||
|
assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
|
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
await helper.register_model(known_model)
|
await helper.register_model(known_model)
|
||||||
|
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_model_existing_different(
|
||||||
|
helper: ModelRegistryHelper, known_model: Model, known_model2: Model
|
||||||
|
) -> None:
|
||||||
|
known_model.provider_resource_id = known_model2.provider_resource_id
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await helper.register_model(known_model)
|
await helper.register_model(known_model)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
await helper.register_model(known_model)
|
await helper.register_model(known_model) # duplicate entry
|
||||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
|
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
|
||||||
await helper.unregister_model(known_model.model_id)
|
await helper.unregister_model(known_model.model_id)
|
||||||
assert helper.get_provider_model_id(known_model.model_id) is None
|
assert helper.get_provider_model_id(known_model.model_id) is None
|
||||||
|
@ -111,14 +156,6 @@ async def test_register_model_during_init(helper: ModelRegistryHelper, known_mod
|
||||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_model_existing_from_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
|
||||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
known_model.identifier = known_model.provider_resource_id
|
|
||||||
await helper.register_model(known_model)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue