Account for if a permitted model is None

This commit is contained in:
Connor Hack 2024-11-20 18:03:34 -08:00
parent 05f1041bfa
commit 16ffe19a20
4 changed files with 5 additions and 8 deletions

View file

@ -1,9 +1,6 @@
name: "Run Llama-stack Tests" name: "Run Llama-stack Tests"
on: on:
push:
branches:
- 'main'
pull_request_target: pull_request_target:
types: ["opened"] types: ["opened"]
branches: branches:

View file

@ -37,8 +37,8 @@ class MetaReferenceInferenceConfig(BaseModel):
@classmethod @classmethod
def validate_model(cls, model: str) -> str: def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models() permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models] descriptors = [m.descriptor() for m in permitted_models if m is not None]
repos = [m.huggingface_repo for m in permitted_models] repos = [m.huggingface_repo for m in permitted_models if m is not None]
if model not in (descriptors + repos): if model not in (descriptors + repos):
model_list = "\n\t".join(repos) model_list = "\n\t".join(repos)
raise ValueError( raise ValueError(

View file

@ -49,8 +49,8 @@ class VLLMConfig(BaseModel):
def validate_model(cls, model: str) -> str: def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models() permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models] descriptors = [m.descriptor() for m in permitted_models if m is not None]
repos = [m.huggingface_repo for m in permitted_models] repos = [m.huggingface_repo for m in permitted_models if m is not None]
if model not in (descriptors + repos): if model not in (descriptors + repos):
model_list = "\n\t".join(repos) model_list = "\n\t".join(repos)
raise ValueError( raise ValueError(

View file

@ -179,7 +179,7 @@ def chat_completion_request_to_messages(
return request.messages return request.messages
allowed_models = supported_inference_models() allowed_models = supported_inference_models()
descriptors = [m.descriptor() for m in allowed_models] descriptors = [m.descriptor() for m in allowed_models if m is not None]
if model.descriptor() not in descriptors: if model.descriptor() not in descriptors:
cprint(f"Unsupported inference model? {model.descriptor()}", color="red") cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
return request.messages return request.messages