From 84351c2d67717a4973ff1770e6d890eec29bd9ea Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 22 Apr 2025 09:33:06 -0400 Subject: [PATCH] flatten alias to provider map, test registering to existing alias --- .../utils/inference/model_registry.py | 57 +++++++-------- .../providers/utils/test_model_registry.py | 71 ++++++++++++++----- 2 files changed, 79 insertions(+), 49 deletions(-) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index f1f1a324a..c5199b0a8 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -59,8 +59,6 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider class ModelRegistryHelper(ModelsProtocolPrivate): def __init__(self, model_entries: List[ProviderModelEntry]): - self.supported_model_ids = {entry.provider_model_id for entry in model_entries} - self.alias_to_provider_id_map = {} self.provider_id_to_llama_model_map = {} for entry in model_entries: @@ -77,54 +75,49 @@ class ModelRegistryHelper(ModelsProtocolPrivate): def get_provider_model_id(self, identifier: str) -> Optional[str]: return self.alias_to_provider_id_map.get(identifier, None) + # TODO: why keep a separate llama model mapping? def get_llama_model(self, provider_model_id: str) -> Optional[str]: return self.provider_id_to_llama_model_map.get(provider_model_id, None) async def register_model(self, model: Model) -> Model: - if model.provider_resource_id not in self.supported_model_ids: + if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)): raise ValueError( - f"Model id '{model.provider_resource_id}' is not supported. Supported ids are: {', '.join(self.supported_model_ids)}" + f"Model '{model.provider_resource_id}' is not supported. Supported models are: {', '.join(self.alias_to_provider_id_map.keys())}" ) - if model.model_id in self.alias_to_provider_id_map: - # be idemopotent - if model.provider_resource_id != self.alias_to_provider_id_map[model.model_id]: - raise ValueError( - f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first." - ) + provider_resource_id = self.get_provider_model_id(model.model_id) if model.model_type == ModelType.embedding: # embedding models are always registered by their provider model id and does not need to be mapped to a llama model provider_resource_id = model.provider_resource_id - else: - provider_resource_id = self.get_provider_model_id(model.provider_resource_id) - if provider_resource_id: - model.provider_resource_id = provider_resource_id + if provider_resource_id != supported_model_id: # be idemopotent, only reject differences + raise ValueError( + f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first." + ) else: llama_model = model.metadata.get("llama_model") - if llama_model is None: - return model - - existing_llama_model = self.get_llama_model(model.provider_resource_id) - if existing_llama_model: - if existing_llama_model != llama_model: - raise ValueError( - f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" + if llama_model: + existing_llama_model = self.get_llama_model(model.provider_resource_id) + if existing_llama_model: + if existing_llama_model != llama_model: + raise ValueError( + f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" + ) + else: + if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: + raise ValueError( + f"Invalid llama_model '{llama_model}' specified in metadata. " + f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}" + ) + self.provider_id_to_llama_model_map[model.provider_resource_id] = ( + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] ) - else: - if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: - raise ValueError( - f"Invalid llama_model '{llama_model}' specified in metadata. " - f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}" - ) - self.provider_id_to_llama_model_map[model.provider_resource_id] = ( - ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] - ) - self.alias_to_provider_id_map[model.model_id] = model.provider_resource_id + self.alias_to_provider_id_map[model.model_id] = supported_model_id return model async def unregister_model(self, model_id: str) -> None: + # TODO: should we block unregistering base supported provider model IDs? if model_id not in self.alias_to_provider_id_map: raise ValueError(f"Model id '{model_id}' is not registered.") diff --git a/tests/unit/providers/utils/test_model_registry.py b/tests/unit/providers/utils/test_model_registry.py index d007e3a40..67f8a138f 100644 --- a/tests/unit/providers/utils/test_model_registry.py +++ b/tests/unit/providers/utils/test_model_registry.py @@ -11,11 +11,14 @@ # # Test cases - # - Looking up an alias that does not exist should return None. -# - Registering a model + provider ID should add the model to the registry. If provider ID is known. -# - Registering an existing model should return an error. +# - Registering a model + provider ID should add the model to the registry. If +# provider ID is known or an alias for a provider ID. +# - Registering an existing model should return an error. Unless it's a +# dulicate entry. # - Unregistering a model should remove it from the registry. # - Unregistering a model that does not exist should return an error. -# - Models can be registered during initialization or via register_model. +# - Supported model ID and their aliases are registered during initialization. +# Only aliases are added afterwards. # # Questions - # - Should we be allowed to register models w/o provider model IDs? No. @@ -45,10 +48,27 @@ def known_model() -> Model: ) +@pytest.fixture +def known_model2() -> Model: + return Model( + provider_id="provider", + identifier="known-model2", + provider_resource_id="known-provider-id2", + ) + + @pytest.fixture def known_provider_model(known_model: Model) -> ProviderModelEntry: return ProviderModelEntry( provider_model_id=known_model.provider_resource_id, + aliases=[known_model.model_id], + ) + + +@pytest.fixture +def known_provider_model2(known_model2: Model) -> ProviderModelEntry: + return ProviderModelEntry( + provider_model_id=known_model2.provider_resource_id, # aliases=[], ) @@ -63,8 +83,8 @@ def unknown_model() -> Model: @pytest.fixture -def helper(known_provider_model: ProviderModelEntry) -> ModelRegistryHelper: - return ModelRegistryHelper([known_provider_model]) +def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper: + return ModelRegistryHelper([known_provider_model, known_provider_model2]) @pytest.mark.asyncio @@ -80,21 +100,46 @@ async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unkn @pytest.mark.asyncio async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None: - assert helper.get_provider_model_id(known_model.model_id) is None - await helper.register_model(known_model) - assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id + model = Model( + provider_id=known_model.provider_id, + identifier="new-model", + provider_resource_id=known_model.provider_resource_id, + ) + assert helper.get_provider_model_id(model.model_id) is None + await helper.register_model(model) + assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id + + +@pytest.mark.asyncio +async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None: + model = Model( + provider_id=known_model.provider_id, + identifier="new-model", + provider_resource_id=known_model.model_id, # use known model's id as an alias for the supported model id + ) + assert helper.get_provider_model_id(model.model_id) is None + await helper.register_model(model) + assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id @pytest.mark.asyncio async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None: await helper.register_model(known_model) + assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id + + +@pytest.mark.asyncio +async def test_register_model_existing_different( + helper: ModelRegistryHelper, known_model: Model, known_model2: Model +) -> None: + known_model.provider_resource_id = known_model2.provider_resource_id with pytest.raises(ValueError): await helper.register_model(known_model) @pytest.mark.asyncio async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None: - await helper.register_model(known_model) + await helper.register_model(known_model) # duplicate entry assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id await helper.unregister_model(known_model.model_id) assert helper.get_provider_model_id(known_model.model_id) is None @@ -111,14 +156,6 @@ async def test_register_model_during_init(helper: ModelRegistryHelper, known_mod assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id -@pytest.mark.asyncio -async def test_register_model_existing_from_init(helper: ModelRegistryHelper, known_model: Model) -> None: - assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id - with pytest.raises(ValueError): - known_model.identifier = known_model.provider_resource_id - await helper.register_model(known_model) - - @pytest.mark.asyncio async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None: assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id