add validation for configuration input

This commit is contained in:
Ashwin Bharambe 2024-08-08 10:04:39 -07:00
parent ab856c174c
commit 9e3182216d
3 changed files with 124 additions and 48 deletions

View file

@@ -6,9 +6,12 @@
from typing import Optional
from llama_models.schema_utils import json_schema_type
from llama_models.datatypes import ModelFamily
from pydantic import BaseModel
from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models
from pydantic import BaseModel, validator
from llama_toolchain.inference.api import QuantizationConfig
@@ -20,3 +23,18 @@ class MetaReferenceImplConfig(BaseModel):
torch_seed: Optional[int] = None
max_seq_len: int
max_batch_size: int = 1
@validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
    """Reject any `model` descriptor that is not a registered Llama 3.1 model.

    Raises:
        ValueError: if `model` is not among the Llama 3.1 family descriptors,
            with the full list of permitted descriptors in the message.
    """
    allowed = [
        entry.descriptor()
        for entry in all_registered_models()
        if entry.model_family == ModelFamily.llama3_1
    ]
    # Guard-clause form: accept early, otherwise build the readable listing.
    if model in allowed:
        return model
    model_list = "\n\t".join(allowed)
    raise ValueError(
        f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
    )