mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
address comments
This commit is contained in:
parent
3378c100f6
commit
aad0dedc85
1 changed files with 13 additions and 5 deletions
|
@ -43,6 +43,10 @@ class ModelConfigs(BaseModel):
|
||||||
Llama_3_8B_Instruct: ModelConfig
|
Llama_3_8B_Instruct: ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetSchema(BaseModel):
|
||||||
|
alpaca: List[Dict[str, ParamType]]
|
||||||
|
|
||||||
|
|
||||||
MODEL_CONFIGS = ModelConfigs(
|
MODEL_CONFIGS = ModelConfigs(
|
||||||
Llama3_2_3B_Instruct=ModelConfig(
|
Llama3_2_3B_Instruct=ModelConfig(
|
||||||
model_definition=lora_llama3_2_3b,
|
model_definition=lora_llama3_2_3b,
|
||||||
|
@ -56,8 +60,9 @@ MODEL_CONFIGS = ModelConfigs(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
|
|
||||||
"alpaca": [
|
EXPECTED_DATASET_SCHEMA = DatasetSchema(
|
||||||
|
alpaca=[
|
||||||
{
|
{
|
||||||
ColumnName.instruction.value: StringType(),
|
ColumnName.instruction.value: StringType(),
|
||||||
ColumnName.input.value: StringType(),
|
ColumnName.input.value: StringType(),
|
||||||
|
@ -74,7 +79,7 @@ EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
|
||||||
ColumnName.output.value: StringType(),
|
ColumnName.output.value: StringType(),
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
}
|
)
|
||||||
|
|
||||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||||
|
@ -138,7 +143,10 @@ async def validate_input_dataset_schema(
|
||||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||||
|
|
||||||
if dataset_def.dataset_schema not in EXPECTED_DATASET_SCHEMA[dataset_type]:
|
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||||
|
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
||||||
|
|
||||||
|
if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Dataset {dataset_id} does not have a correct input schema in {EXPECTED_DATASET_SCHEMA[dataset_type]}"
|
f"Dataset {dataset_id} does not have a correct input schema in {EXPECTED_DATASET_SCHEMA.dataset_type}"
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue