diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index 11648b117..4713e7f99 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -37,8 +37,10 @@ class MetaReferenceInferenceConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        if model not in permitted_models:
-            model_list = "\n\t".join(permitted_models)
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
+        if model not in (descriptors + repos):
+            model_list = "\n\t".join(repos)
             raise ValueError(
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py
index e5516673c..8a95298f4 100644
--- a/llama_stack/providers/inline/inference/vllm/config.py
+++ b/llama_stack/providers/inline/inference/vllm/config.py
@@ -48,8 +48,11 @@ class VLLMConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        if model not in permitted_models:
-            model_list = "\n\t".join(permitted_models)
+
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
+        if model not in (descriptors + repos):
+            model_list = "\n\t".join(repos)
             raise ValueError(
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
index 07100c982..1471bc369 100644
--- a/llama_stack/providers/tests/inference/test_model_registration.py
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -11,7 +11,6 @@ import pytest
 #
 # pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
 # -m "meta_reference"
-# --env TOGETHER_API_KEY=


 class TestModelRegistration:
diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py
index 7d268ed38..d204f98a4 100644
--- a/llama_stack/providers/utils/inference/__init__.py
+++ b/llama_stack/providers/utils/inference/__init__.py
@@ -22,9 +22,9 @@ def is_supported_safety_model(model: Model) -> bool:
     ]


-def supported_inference_models() -> List[str]:
+def supported_inference_models() -> List[Model]:
     return [
-        m.descriptor()
+        m
         for m in all_registered_models()
         if (
             m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 2df04664f..6e4d0752e 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -178,7 +178,9 @@ def chat_completion_request_to_messages(
         cprint(f"Could not resolve model {llama_model}", color="red")
         return request.messages

-    if model.descriptor() not in supported_inference_models():
+    allowed_models = supported_inference_models()
+    descriptors = [m.descriptor() for m in allowed_models]
+    if model.descriptor() not in descriptors:
         cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
         return request.messages
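
Reviewer note: a minimal, self-contained sketch of the validation behavior this diff enables. The Model stub, the supported_inference_models() stand-in, and the descriptor/repo strings below are illustrative assumptions, not pulled from the real llama_models registry:

# Sketch only: a model reference now validates if it matches EITHER a
# llama-models descriptor OR a Hugging Face repo, and the error message
# lists the repos. All names below are stand-ins for the real types.
from dataclasses import dataclass
from typing import List


@dataclass
class Model:  # stub for the Model type returned by all_registered_models()
    _descriptor: str
    huggingface_repo: str

    def descriptor(self) -> str:
        return self._descriptor


def supported_inference_models() -> List[Model]:
    # Stand-in for llama_stack.providers.utils.inference.supported_inference_models;
    # the real list is derived from all_registered_models().
    return [Model("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct")]


def validate_model(model: str) -> str:
    # Mirrors the updated validators in meta_reference/config.py and vllm/config.py.
    permitted_models = supported_inference_models()
    descriptors = [m.descriptor() for m in permitted_models]
    repos = [m.huggingface_repo for m in permitted_models]
    if model not in (descriptors + repos):
        model_list = "\n\t".join(repos)
        raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
    return model


validate_model("Llama3.1-8B-Instruct")              # accepted: descriptor
validate_model("meta-llama/Llama-3.1-8B-Instruct")  # accepted: HF repo (new with this diff)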