mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix: add todo for schema validation (#1991)
# What does this PR do? Change validation to TODO same as was done [here](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/eval/meta_reference/eval.py#L87) until validation can be implemented Closes #1849 ## Test Plan Signed-off-by: Kevin <kpostlet@redhat.com>
This commit is contained in:
parent
fe9b5ef08b
commit
2aca7265b3
2 changed files with 3 additions and 28 deletions
|
@ -17,10 +17,8 @@ from llama_stack.apis.common.type_system import (
|
||||||
DialogType,
|
DialogType,
|
||||||
StringType,
|
StringType,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.datasets import Datasets
|
|
||||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||||
ColumnName,
|
ColumnName,
|
||||||
validate_dataset_schema,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
||||||
|
@ -36,21 +34,3 @@ EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def validate_input_dataset_schema(
|
|
||||||
datasets_api: Datasets,
|
|
||||||
dataset_id: str,
|
|
||||||
dataset_type: str,
|
|
||||||
) -> None:
|
|
||||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
|
||||||
if not dataset_def:
|
|
||||||
raise ValueError(f"Dataset {dataset_id} does not exist.")
|
|
||||||
|
|
||||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
|
||||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
|
||||||
|
|
||||||
if dataset_type not in EXPECTED_DATASET_SCHEMA:
|
|
||||||
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
|
||||||
|
|
||||||
validate_dataset_schema(dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type])
|
|
||||||
|
|
|
@ -48,9 +48,6 @@ from llama_stack.apis.post_training import (
|
||||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from llama_stack.providers.inline.post_training.common.validator import (
|
|
||||||
validate_input_dataset_schema,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||||
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
||||||
TorchtuneCheckpointer,
|
TorchtuneCheckpointer,
|
||||||
|
@ -348,11 +345,9 @@ class LoraFinetuningSingleDevice:
|
||||||
all_rows = await fetch_rows(dataset_id)
|
all_rows = await fetch_rows(dataset_id)
|
||||||
rows = all_rows.data
|
rows = all_rows.data
|
||||||
|
|
||||||
await validate_input_dataset_schema(
|
# TODO (xiyan): validate dataset schema
|
||||||
datasets_api=self.datasets_api,
|
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||||
dataset_id=dataset_id,
|
|
||||||
dataset_type=self._data_format.value,
|
|
||||||
)
|
|
||||||
data_transform = await utils.get_data_transform(self._data_format)
|
data_transform = await utils.get_data_transform(self._data_format)
|
||||||
ds = SFTDataset(
|
ds = SFTDataset(
|
||||||
rows,
|
rows,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue