diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
index 0f07badfa..07100c982 100644
--- a/llama_stack/providers/tests/inference/test_model_registration.py
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -6,7 +6,6 @@
 
 import pytest
 
-from llama_models.datatypes import CoreModelId
 
 # How to run this test:
 #
@@ -17,11 +16,22 @@ from llama_models.datatypes import CoreModelId
 
 class TestModelRegistration:
     @pytest.mark.asyncio
-    async def test_register_unsupported_model(self, inference_stack):
-        _, models_impl = inference_stack
+    async def test_register_unsupported_model(self, inference_stack, inference_model):
+        inference_impl, models_impl = inference_stack
+
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type not in (
+            "meta-reference",
+            "remote::ollama",
+            "remote::vllm",
+            "remote::tgi",
+        ):
+            pytest.skip(
+                "Skipping test for remote inference providers since they can handle large models like 70B instruct"
+            )
 
         # Try to register a model that's too large for local inference
-        with pytest.raises(Exception) as exc_info:
+        with pytest.raises(ValueError) as exc_info:
             await models_impl.register_model(
                 model_id="Llama3.1-70B-Instruct",
             )
@@ -37,21 +47,27 @@ class TestModelRegistration:
         )
 
     @pytest.mark.asyncio
-    async def test_update_model(self, inference_stack):
+    async def test_register_with_llama_model(self, inference_stack):
         _, models_impl = inference_stack
 
-        # Register a model to update
-        model_id = CoreModelId.llama3_1_8b_instruct.value
-        old_model = await models_impl.register_model(model_id=model_id)
-
-        # Update the model
-        new_model_id = CoreModelId.llama3_2_3b_instruct.value
-        updated_model = await models_impl.update_model(
-            model_id=model_id, provider_model_id=new_model_id
+        _ = await models_impl.register_model(
+            model_id="custom-model",
+            metadata={"llama_model": "meta-llama/Llama-2-7b"},
         )
 
-        # Retrieve the updated model to verify changes
-        assert updated_model.provider_resource_id != old_model.provider_resource_id
+        with pytest.raises(ValueError) as exc_info:
+            await models_impl.register_model(
+                model_id="custom-model-2",
+                metadata={"llama_model": "meta-llama/Llama-2-7b"},
+                provider_model_id="custom-model",
+            )
 
-        # Cleanup
-        await models_impl.unregister_model(model_id=model_id)
+    @pytest.mark.asyncio
+    async def test_register_with_invalid_llama_model(self, inference_stack):
+        _, models_impl = inference_stack
+
+        with pytest.raises(ValueError) as exc_info:
+            await models_impl.register_model(
+                model_id="custom-model-2",
+                metadata={"llama_model": "invalid-llama-model"},
+            )
diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py
index 55f72a791..7d268ed38 100644
--- a/llama_stack/providers/utils/inference/__init__.py
+++ b/llama_stack/providers/utils/inference/__init__.py
@@ -31,3 +31,8 @@ def supported_inference_models() -> List[str]:
             or is_supported_safety_model(m)
         )
     ]
+
+
+ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {
+    m.huggingface_repo: m.descriptor() for m in all_registered_models()
+}
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 77eb5b415..3834946f5 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -11,6 +11,10 @@ from llama_models.sku_list import all_registered_models
 
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 
+from llama_stack.providers.utils.inference import (
+    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
+)
+
 ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
 
 
@@ -51,7 +55,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         if identifier in self.alias_to_provider_id_map:
             return self.alias_to_provider_id_map[identifier]
         else:
-            raise ValueError(f"Unknown model: `{identifier}`")
+            return None
 
     def get_llama_model(self, provider_model_id: str) -> str:
         if provider_model_id in self.provider_id_to_llama_model_map:
@@ -60,8 +64,34 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             return None
 
     async def register_model(self, model: Model) -> Model:
-        model.provider_resource_id = self.get_provider_model_id(
-            model.provider_resource_id
-        )
+        provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+        if provider_resource_id:
+            model.provider_resource_id = provider_resource_id
+        else:
+            if model.metadata.get("llama_model") is None:
+                raise ValueError(
+                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
+                    "Please specify a llama_model in metadata or use a supported model identifier"
+                )
+            existing_llama_model = self.get_llama_model(model.provider_resource_id)
+            if existing_llama_model:
+                if existing_llama_model != model.metadata["llama_model"]:
+                    raise ValueError(
+                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
+                    )
+            else:
+                if (
+                    model.metadata["llama_model"]
+                    not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR
+                ):
+                    raise ValueError(
+                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
+                    )
+                self.provider_id_to_llama_model_map[model.provider_resource_id] = (
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[
+                        model.metadata["llama_model"]
+                    ]
+                )
         return model
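
A quick way to sanity-check the new register_model() branching is to mirror its decision tree over plain dictionaries. The sketch below is illustrative only: the two maps stand in for ModelRegistryHelper's alias_to_provider_id_map and provider_id_to_llama_model_map from the patch, while the sample entries and the standalone register() function are hypothetical, not part of the patch.

# Minimal sketch of the register_model() decision tree from the patch above.
# The sample data and this standalone register() function are hypothetical;
# only the branching structure is taken from the diff.

ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {
    # Assumed example entry (HF repo -> descriptor); the real map is built
    # from all_registered_models() in the patch.
    "meta-llama/Llama-2-7b": "Llama-2-7b",
}

alias_to_provider_id_map = {"llama2-7b": "meta-llama/Llama-2-7b"}
provider_id_to_llama_model_map = {}


def register(provider_resource_id: str, metadata: dict) -> str:
    # Known alias: resolve it directly, nothing else to validate.
    if provider_resource_id in alias_to_provider_id_map:
        return alias_to_provider_id_map[provider_resource_id]

    # Unknown id: the caller must declare which llama model it serves.
    llama_model = metadata.get("llama_model")
    if llama_model is None:
        raise ValueError("model not available and no llama_model in metadata")

    # Provider id already bound by an earlier registration. Note the map
    # stores the *descriptor* while metadata carries the HF repo name, so
    # re-registering a bound id raises here.
    existing = provider_id_to_llama_model_map.get(provider_resource_id)
    if existing is not None:
        if existing != llama_model:
            raise ValueError("already registered to a different llama model")
    else:
        # Fresh binding: validate against the known HF repos, then record
        # the descriptor under the provider's model id.
        if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
            raise ValueError("invalid llama_model in metadata")
        provider_id_to_llama_model_map[provider_resource_id] = (
            ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
        )
    return provider_resource_id


# Mirrors the new tests: an alias resolves, and a custom id with a valid
# llama_model in metadata registers successfully.
assert register("llama2-7b", {}) == "meta-llama/Llama-2-7b"
assert register("custom-model", {"llama_model": "meta-llama/Llama-2-7b"}) == "custom-model"

The descriptor-vs-repo comparison in the "already bound" branch is why test_register_with_llama_model expects a ValueError when "custom-model-2" reuses provider_model_id="custom-model": the stored descriptor never equals the repo string supplied in metadata.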