From e84d4436b51260b2ad42cea2df5eeccc4f6fe9b6 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 20 Nov 2024 16:14:37 -0800
Subject: [PATCH] Since we are pushing for HF repos, we should accept them in
 inference configs (#497)

# What does this PR do?

As the title says: accept Hugging Face repo names, in addition to model descriptors, when validating the `model` field of the meta-reference and vLLM inference configs.

## Test Plan

This needs https://github.com/meta-llama/llama-models/commit/8752149f58654c54c012209f43b57bb476146f0c to also land, so the next package (0.0.54) will make this work properly.

The test is:

```bash
pytest -v -s -m "llama_3b and meta_reference" test_model_registration.py
```
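
For illustration, a minimal sketch of what this enables once 0.0.54 is out. The model strings below are assumed examples (not taken from this patch), and the snippet assumes the remaining config fields keep their defaults:

```python
# Illustrative sketch only -- assumes llama-stack with this patch plus the
# llama-models commit referenced above; the model strings are examples and
# the other config fields are assumed to have usable defaults.
from llama_stack.providers.inline.inference.meta_reference.config import (
    MetaReferenceInferenceConfig,
)

# The native descriptor was already accepted by the field validator:
cfg = MetaReferenceInferenceConfig(model="Llama3.2-3B-Instruct")

# With this change, the corresponding Hugging Face repo name should validate too:
cfg = MetaReferenceInferenceConfig(model="meta-llama/Llama-3.2-3B-Instruct")
```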
---
 .../providers/inline/inference/meta_reference/config.py  | 6 ++++--
 llama_stack/providers/inline/inference/vllm/config.py    | 7 +++++--
 .../providers/tests/inference/test_model_registration.py | 1 -
 llama_stack/providers/utils/inference/__init__.py        | 4 ++--
 llama_stack/providers/utils/inference/prompt_adapter.py  | 4 +++-
 5 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index 11648b117..4713e7f99 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -37,8 +37,10 @@ class MetaReferenceInferenceConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        if model not in permitted_models:
-            model_list = "\n\t".join(permitted_models)
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
+        if model not in (descriptors + repos):
+            model_list = "\n\t".join(repos)
             raise ValueError(
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py
index e5516673c..8a95298f4 100644
--- a/llama_stack/providers/inline/inference/vllm/config.py
+++ b/llama_stack/providers/inline/inference/vllm/config.py
@@ -48,8 +48,11 @@ class VLLMConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        if model not in permitted_models:
-            model_list = "\n\t".join(permitted_models)
+
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
+        if model not in (descriptors + repos):
+            model_list = "\n\t".join(repos)
             raise ValueError(
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
index 07100c982..1471bc369 100644
--- a/llama_stack/providers/tests/inference/test_model_registration.py
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -11,7 +11,6 @@ import pytest
 #
 # pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
 #   -m "meta_reference"
 #
-#   --env TOGETHER_API_KEY=
 
 class TestModelRegistration:
diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py
index 7d268ed38..d204f98a4 100644
--- a/llama_stack/providers/utils/inference/__init__.py
+++ b/llama_stack/providers/utils/inference/__init__.py
@@ -22,9 +22,9 @@ def is_supported_safety_model(model: Model) -> bool:
     ]
 
 
-def supported_inference_models() -> List[str]:
+def supported_inference_models() -> List[Model]:
     return [
-        m.descriptor()
+        m
         for m in all_registered_models()
         if (
             m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 2df04664f..6e4d0752e 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -178,7 +178,9 @@ def chat_completion_request_to_messages(
         cprint(f"Could not resolve model {llama_model}", color="red")
         return request.messages
 
-    if model.descriptor() not in supported_inference_models():
+    allowed_models = supported_inference_models()
+    descriptors = [m.descriptor() for m in allowed_models]
+    if model.descriptor() not in descriptors:
         cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
         return request.messages
 
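
One note for other direct callers of `supported_inference_models()`: it now returns `Model` objects instead of descriptor strings, so string membership checks need the same small adaptation as the `prompt_adapter.py` hunk above. A minimal sketch, where the `is_allowed` helper is hypothetical and not part of this patch:

```python
# Hypothetical helper, mirroring the prompt_adapter.py change above; it is not
# part of this patch.
from llama_stack.providers.utils.inference import supported_inference_models


def is_allowed(model_id: str) -> bool:
    allowed = supported_inference_models()  # now List[Model], not List[str]
    descriptors = [m.descriptor() for m in allowed]
    repos = [m.huggingface_repo for m in allowed]
    return model_id in descriptors or model_id in repos
```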