mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
address comments
This commit is contained in:
parent
2a992d4f05
commit
f39dcdec9d
2 changed files with 13 additions and 14 deletions
|
@ -10,7 +10,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
@ -20,6 +19,10 @@ from llama_models.sku_list import resolve_model
|
|||
from llama_stack.apis.common.type_system import ParamType, StringType
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import DatasetFormat
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
ColumnName,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from torchtune.data._messages import (
|
||||
|
@ -36,15 +39,6 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
|
|||
from torchtune.modules.transforms import Transform
|
||||
|
||||
|
||||
class ColumnName(Enum):
|
||||
instruction = "instruction"
|
||||
input = "input"
|
||||
output = "output"
|
||||
text = "text"
|
||||
conversations = "conversations"
|
||||
messages = "messages"
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_definition: Any
|
||||
tokenizer_type: Any
|
||||
|
@ -191,7 +185,6 @@ async def validate_input_dataset_schema(
|
|||
else:
|
||||
dataset_schema = dataset_def.dataset_schema
|
||||
|
||||
if dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
|
||||
)
|
||||
validate_dataset_schema(
|
||||
dataset_schema, getattr(EXPECTED_DATASET_SCHEMA, dataset_type)
|
||||
)
|
||||
|
|
|
@ -23,6 +23,12 @@ class ColumnName(Enum):
|
|||
completion_input = "completion_input"
|
||||
generated_answer = "generated_answer"
|
||||
context = "context"
|
||||
instruction = "instruction"
|
||||
input = "input"
|
||||
output = "output"
|
||||
text = "text"
|
||||
conversations = "conversations"
|
||||
messages = "messages"
|
||||
|
||||
|
||||
VALID_SCHEMAS_FOR_SCORING = [
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue