From 16ffe19a2048345bbee50d5d8ac8ebd1a091310a Mon Sep 17 00:00:00 2001
From: Connor Hack
Date: Wed, 20 Nov 2024 18:03:34 -0800
Subject: [PATCH] Account for if a permitted model is None

---
 .github/workflows/gha_workflow_llama_stack_tests.yml     | 3 ---
 .../providers/inline/inference/meta_reference/config.py  | 4 ++--
 llama_stack/providers/inline/inference/vllm/config.py    | 4 ++--
 llama_stack/providers/utils/inference/prompt_adapter.py  | 2 +-
 4 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/gha_workflow_llama_stack_tests.yml b/.github/workflows/gha_workflow_llama_stack_tests.yml
index ee3451d7b..8f7a25ee4 100644
--- a/.github/workflows/gha_workflow_llama_stack_tests.yml
+++ b/.github/workflows/gha_workflow_llama_stack_tests.yml
@@ -1,9 +1,6 @@
 name: "Run Llama-stack Tests"
 on:
-  push:
-    branches:
-      - 'main'
   pull_request_target:
     types: ["opened"]
     branches:
diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index 4713e7f99..564e5a708 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -37,8 +37,8 @@ class MetaReferenceInferenceConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        descriptors = [m.descriptor() for m in permitted_models]
-        repos = [m.huggingface_repo for m in permitted_models]
+        descriptors = [m.descriptor() for m in permitted_models if m is not None]
+        repos = [m.huggingface_repo for m in permitted_models if m is not None]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py
index 8a95298f4..2a39d0096 100644
--- a/llama_stack/providers/inline/inference/vllm/config.py
+++ b/llama_stack/providers/inline/inference/vllm/config.py
@@ -49,8 +49,8 @@ class VLLMConfig(BaseModel):
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
 
-        descriptors = [m.descriptor() for m in permitted_models]
-        repos = [m.huggingface_repo for m in permitted_models]
+        descriptors = [m.descriptor() for m in permitted_models if m is not None]
+        repos = [m.huggingface_repo for m in permitted_models if m is not None]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 6e4d0752e..5d81bb4b1 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -179,7 +179,7 @@ def chat_completion_request_to_messages(
         return request.messages
 
     allowed_models = supported_inference_models()
-    descriptors = [m.descriptor() for m in allowed_models]
+    descriptors = [m.descriptor() for m in allowed_models if m is not None]
     if model.descriptor() not in descriptors:
         cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
         return request.messages
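
Context for the three Python hunks: supported_inference_models() can return a list containing None entries, and the unguarded comprehensions then fail with AttributeError: 'NoneType' object has no attribute 'descriptor'. The sketch below reproduces the guarded pattern in isolation; the _Model class, the stubbed supported_inference_models(), and the ValueError message text are hypothetical stand-ins for illustration, not llama_stack's actual implementations.

from typing import List, Optional


class _Model:
    # Hypothetical stand-in for the resolved model type; the real objects come
    # from llama_stack's model registry.
    def __init__(self, name: str, huggingface_repo: str) -> None:
        self.name = name
        self.huggingface_repo = huggingface_repo

    def descriptor(self) -> str:
        return self.name


def supported_inference_models() -> List[Optional[_Model]]:
    # Stub: the real helper may yield None for entries it cannot resolve,
    # which is the case this patch guards against.
    return [_Model("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"), None]


def validate_model(model: str) -> str:
    permitted_models = supported_inference_models()
    # Filter out None before touching attributes; otherwise .descriptor() or
    # .huggingface_repo on a None entry raises AttributeError.
    descriptors = [m.descriptor() for m in permitted_models if m is not None]
    repos = [m.huggingface_repo for m in permitted_models if m is not None]
    if model not in (descriptors + repos):
        model_list = "\n\t".join(repos)
        raise ValueError(f"Unknown model: `{model}`. Permitted models:\n\t{model_list}")
    return model


print(validate_model("meta-llama/Llama-3.1-8B-Instruct"))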