diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 4fa7f4e61..07100c982 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -6,8 +6,6 @@ import pytest -from llama_models.datatypes import CoreModelId - # How to run this test: # @@ -28,10 +26,12 @@ class TestModelRegistration: "remote::vllm", "remote::tgi", ): - pytest.skip("70B instruct is too big only for local inference providers") + pytest.skip( + "Skipping test for remote inference providers since they can handle large models like 70B instruct" + ) # Try to register a model that's too large for local inference - with pytest.raises(Exception) as exc_info: + with pytest.raises(ValueError) as exc_info: await models_impl.register_model( model_id="Llama3.1-70B-Instruct", ) @@ -52,13 +52,13 @@ class TestModelRegistration: _ = await models_impl.register_model( model_id="custom-model", - metadata={"llama_model": CoreModelId.llama3_1_8b_instruct.value}, + metadata={"llama_model": "meta-llama/Llama-2-7b"}, ) with pytest.raises(ValueError) as exc_info: await models_impl.register_model( model_id="custom-model-2", - metadata={"llama_model": CoreModelId.llama3_2_3b_instruct.value}, + metadata={"llama_model": "meta-llama/Llama-2-7b"}, provider_model_id="custom-model", ) @@ -66,7 +66,7 @@ class TestModelRegistration: async def test_register_with_invalid_llama_model(self, inference_stack): _, models_impl = inference_stack - with pytest.raises(Exception) as exc_info: + with pytest.raises(ValueError) as exc_info: await models_impl.register_model( model_id="custom-model-2", metadata={"llama_model": "invalid-llama-model"}, diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py index 55f72a791..7d268ed38 100644 --- a/llama_stack/providers/utils/inference/__init__.py +++ b/llama_stack/providers/utils/inference/__init__.py @@ -31,3 +31,8 @@ def supported_inference_models() -> List[str]: or is_supported_safety_model(m) ) ] + + +ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = { + m.huggingface_repo: m.descriptor() for m in all_registered_models() +} diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 98810783d..3834946f5 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -7,11 +7,14 @@ from collections import namedtuple from typing import List, Optional -from llama_models.datatypes import CoreModelId from llama_models.sku_list import all_registered_models from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.utils.inference import ( + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, +) + ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"]) @@ -77,17 +80,18 @@ class ModelRegistryHelper(ModelsProtocolPrivate): f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" ) else: - # Validate that the llama model is a valid CoreModelId - try: - CoreModelId(model.metadata["llama_model"]) - except ValueError as err: + if ( + model.metadata["llama_model"] + not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR + ): raise ValueError( f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. " - f"Must be one of: {', '.join(m.value for m in CoreModelId)}" - ) from err - # Register the mapping from provider model id to llama model for future lookups + f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}" + ) self.provider_id_to_llama_model_map[model.provider_resource_id] = ( - model.metadata["llama_model"] + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[ + model.metadata["llama_model"] + ] ) return model