This commit is contained in:
Xi Yan 2024-10-22 20:00:43 -07:00
parent 821810657f
commit aefa84e70a
5 changed files with 117 additions and 54 deletions

View file

@ -3,10 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
from typing import List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
@ -14,6 +14,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import MetaReferenceDatasetIOConfig
@ -57,6 +58,15 @@ class PandasDataframeDataset(BaseDataset):
else:
return self.df.iloc[idx].to_dict()
def validate_dataset_schema(self) -> None:
if self.df is None:
self.load()
print(self.dataset_def.dataset_schema)
# get columns names
# columns = self.df[self.dataset_def.dataset_schema.keys()]
print(self.df.columns)
def load(self) -> None:
if self.df is not None:
return
@ -105,6 +115,7 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_def: DatasetDef,
) -> None:
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.validate_dataset_schema()
self.dataset_infos[dataset_def.identifier] = DatasetInfo(
dataset_def=dataset_def,
dataset_impl=dataset_impl,