From 490c5fb73071e20e412a14985f42b4de583be005 Mon Sep 17 00:00:00 2001
From: Connor Hack
Date: Wed, 20 Nov 2024 19:17:44 -0800
Subject: [PATCH] Undo None check and temporarily move if model check before builder

---
 .../providers/inline/inference/meta_reference/config.py    | 4 ++--
 .../providers/inline/inference/meta_reference/inference.py | 4 ++--
 llama_stack/providers/inline/inference/vllm/config.py      | 4 ++--
 llama_stack/providers/utils/inference/prompt_adapter.py    | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index 564e5a708..4713e7f99 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -37,8 +37,8 @@ class MetaReferenceInferenceConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        descriptors = [m.descriptor() for m in permitted_models if m is not None]
-        repos = [m.huggingface_repo for m in permitted_models if m is not None]
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index e6bcd6730..d58ecc8bd 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -34,6 +34,8 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
+        if model is None:
+            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         ModelRegistryHelper.__init__(
             self,
             [
@@ -43,8 +45,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
                 )
             ],
         )
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
         # verify that the checkpoint actually is for this model lol

diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py
index 2a39d0096..8a95298f4 100644
--- a/llama_stack/providers/inline/inference/vllm/config.py
+++ b/llama_stack/providers/inline/inference/vllm/config.py
@@ -49,8 +49,8 @@ class VLLMConfig(BaseModel):
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()

-        descriptors = [m.descriptor() for m in permitted_models if m is not None]
-        repos = [m.huggingface_repo for m in permitted_models if m is not None]
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 5d81bb4b1..6e4d0752e 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -179,7 +179,7 @@ def chat_completion_request_to_messages(
         return request.messages

     allowed_models = supported_inference_models()
-    descriptors = [m.descriptor() for m in allowed_models if m is not None]
+    descriptors = [m.descriptor() for m in allowed_models]
     if model.descriptor() not in descriptors:
         cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
         return request.messages