mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
add validation for configuration input
This commit is contained in:
parent
ab856c174c
commit
9e3182216d
3 changed files with 124 additions and 48 deletions
|
@ -6,9 +6,12 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from llama_models.datatypes import ModelFamily
|
||||
|
||||
from pydantic import BaseModel
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from llama_toolchain.inference.api import QuantizationConfig
|
||||
|
||||
|
@ -20,3 +23,18 @@ class MetaReferenceImplConfig(BaseModel):
|
|||
torch_seed: Optional[int] = None
|
||||
max_seq_len: int
|
||||
max_batch_size: int = 1
|
||||
|
||||
@validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = [
|
||||
m.descriptor()
|
||||
for m in all_registered_models()
|
||||
if m.model_family == ModelFamily.llama3_1
|
||||
]
|
||||
if model not in permitted_models:
|
||||
model_list = "\n\t".join(permitted_models)
|
||||
raise ValueError(
|
||||
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
|
||||
)
|
||||
return model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue