diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_toolchain/inference/meta_reference/config.py
index 6757b1cfa..d9aef32e6 100644
--- a/llama_toolchain/inference/meta_reference/config.py
+++ b/llama_toolchain/inference/meta_reference/config.py
@@ -11,14 +11,17 @@ from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
 from llama_models.sku_list import all_registered_models
-from pydantic import BaseModel, validator
-
 from llama_toolchain.inference.api import QuantizationConfig
+from pydantic import BaseModel, Field, validator
+
 
 @json_schema_type
 class MetaReferenceImplConfig(BaseModel):
-    model: str
+    model: str = Field(
+        default="Meta-Llama3.1-8B-Instruct",
+        description="Model descriptor from `llama model list`",
+    )
     quantization: Optional[QuantizationConfig] = None
     torch_seed: Optional[int] = None
     max_seq_len: int
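
For reviewers, a minimal self-contained sketch (not the actual toolchain class; `ExampleConfig` and the field values are illustrative) of how the new `Field` default and description behave under pydantic v1, which matches the `validator` import in the diff:

```python
# Illustrative only: mirrors the pattern from the diff with a hypothetical class.
from typing import Optional

from pydantic import BaseModel, Field


class ExampleConfig(BaseModel):
    model: str = Field(
        default="Meta-Llama3.1-8B-Instruct",
        description="Model descriptor from `llama model list`",
    )
    torch_seed: Optional[int] = None
    max_seq_len: int


# `model` falls back to the default descriptor when omitted;
# `max_seq_len` is still required (the value here is arbitrary).
cfg = ExampleConfig(max_seq_len=4096)
print(cfg.model)  # -> "Meta-Llama3.1-8B-Instruct"

# The Field description is emitted into the generated JSON schema.
print(ExampleConfig.schema()["properties"]["model"]["description"])
# -> "Model descriptor from `llama model list`"
```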