diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index 564e5a708..4713e7f99 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -37,8 +37,8 @@ class MetaReferenceInferenceConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        descriptors = [m.descriptor() for m in permitted_models if m is not None]
-        repos = [m.huggingface_repo for m in permitted_models if m is not None]
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index e6bcd6730..d58ecc8bd 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -34,6 +34,8 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
+        if model is None:
+            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         ModelRegistryHelper.__init__(
             self,
             [
@@ -43,8 +45,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
                 )
             ],
         )
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model

         # verify that the checkpoint actually is for this model lol
diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py
index 2a39d0096..8a95298f4 100644
--- a/llama_stack/providers/inline/inference/vllm/config.py
+++ b/llama_stack/providers/inline/inference/vllm/config.py
@@ -49,8 +49,8 @@ class VLLMConfig(BaseModel):
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()

-        descriptors = [m.descriptor() for m in permitted_models if m is not None]
-        repos = [m.huggingface_repo for m in permitted_models if m is not None]
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 5d81bb4b1..6e4d0752e 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -179,7 +179,7 @@ def chat_completion_request_to_messages(
         return request.messages

     allowed_models = supported_inference_models()
-    descriptors = [m.descriptor() for m in allowed_models if m is not None]
+    descriptors = [m.descriptor() for m in allowed_models]
     if model.descriptor() not in descriptors:
         cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
         return request.messages