wip add datatypes

This commit is contained in:
Xi Yan 2024-10-10 19:56:19 -07:00
parent 99ed1425fc
commit 9816c9aae6
5 changed files with 175 additions and 57 deletions

View file

@ -16,17 +16,14 @@ class CustomDataset(BaseDataset[DictSample]):
self.dataset = None
self.index = 0
def __iter__(self) -> Iterator[DictSample]:
return self
@property
def dataset_id(self) -> str:
return self.config.identifier
def __next__(self) -> DictSample:
def __iter__(self) -> Iterator[DictSample]:
if not self.dataset:
self.load()
if self.index >= len(self.dataset):
raise StopIteration
sample = DictSample(data=self.dataset[self.index])
self.index += 1
return sample
return (DictSample(data=x) for x in self.dataset)
def __str__(self):
return f"CustomDataset({self.config})"
@ -53,19 +50,15 @@ class HuggingfaceDataset(BaseDataset[DictSample]):
super().__init__()
self.config = config
self.dataset = None
self.index = 0
@property
def dataset_id(self) -> str:
return self.config.identifier
def __iter__(self) -> Iterator[DictSample]:
return self
def __next__(self) -> DictSample:
if not self.dataset:
self.load()
if self.index >= len(self.dataset):
raise StopIteration
sample = DictSample(data=self.dataset[self.index])
self.index += 1
return sample
return (DictSample(data=x) for x in self.dataset)
def __str__(self):
return f"HuggingfaceDataset({self.config})"
@ -79,12 +72,3 @@ class HuggingfaceDataset(BaseDataset[DictSample]):
if self.dataset:
return
self.dataset = load_dataset(self.config.dataset_name, **self.config.kwargs)
# parsed = urlparse(self.url)
# if parsed.scheme != "hf":
# raise ValueError(f"Unknown HF dataset: {self.url}")
# query = parse_qs(parsed.query)
# query = {k: v[0] for k, v in query.items()}
# path = parsed.netloc
# self.dataset = load_dataset(path, **query)