mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 16:54:42 +00:00
Account for if a permitted model is None
This commit is contained in:
parent
05f1041bfa
commit
16ffe19a20
4 changed files with 5 additions and 8 deletions
|
@@ -1,9 +1,6 @@
|
||||||
name: "Run Llama-stack Tests"
|
name: "Run Llama-stack Tests"
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- 'main'
|
|
||||||
pull_request_target:
|
pull_request_target:
|
||||||
types: ["opened"]
|
types: ["opened"]
|
||||||
branches:
|
branches:
|
||||||
|
|
|
@@ -37,8 +37,8 @@ class MetaReferenceInferenceConfig(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_model(cls, model: str) -> str:
|
def validate_model(cls, model: str) -> str:
|
||||||
permitted_models = supported_inference_models()
|
permitted_models = supported_inference_models()
|
||||||
descriptors = [m.descriptor() for m in permitted_models]
|
descriptors = [m.descriptor() for m in permitted_models if m is not None]
|
||||||
repos = [m.huggingface_repo for m in permitted_models]
|
repos = [m.huggingface_repo for m in permitted_models if m is not None]
|
||||||
if model not in (descriptors + repos):
|
if model not in (descriptors + repos):
|
||||||
model_list = "\n\t".join(repos)
|
model_list = "\n\t".join(repos)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@@ -49,8 +49,8 @@ class VLLMConfig(BaseModel):
|
||||||
def validate_model(cls, model: str) -> str:
|
def validate_model(cls, model: str) -> str:
|
||||||
permitted_models = supported_inference_models()
|
permitted_models = supported_inference_models()
|
||||||
|
|
||||||
descriptors = [m.descriptor() for m in permitted_models]
|
descriptors = [m.descriptor() for m in permitted_models if m is not None]
|
||||||
repos = [m.huggingface_repo for m in permitted_models]
|
repos = [m.huggingface_repo for m in permitted_models if m is not None]
|
||||||
if model not in (descriptors + repos):
|
if model not in (descriptors + repos):
|
||||||
model_list = "\n\t".join(repos)
|
model_list = "\n\t".join(repos)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@@ -179,7 +179,7 @@ def chat_completion_request_to_messages(
|
||||||
return request.messages
|
return request.messages
|
||||||
|
|
||||||
allowed_models = supported_inference_models()
|
allowed_models = supported_inference_models()
|
||||||
descriptors = [m.descriptor() for m in allowed_models]
|
descriptors = [m.descriptor() for m in allowed_models if m is not None]
|
||||||
if model.descriptor() not in descriptors:
|
if model.descriptor() not in descriptors:
|
||||||
cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
|
cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
|
||||||
return request.messages
|
return request.messages
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue