generator + scorer Api for MMLU

This commit is contained in:
Xi Yan 2024-10-13 23:27:02 -07:00
parent fb565dfb06
commit a25aff290e
14 changed files with 618 additions and 131 deletions

View file

@ -25,23 +25,27 @@ class CustomDataset(BaseDataset[DictSample]):
self.load()
return (DictSample(data=x) for x in self.dataset)
def __str__(self):
def __str__(self) -> str:
return f"CustomDataset({self.config})"
def __len__(self):
def __len__(self) -> int:
if not self.dataset:
self.load()
return len(self.dataset)
def load(self):
def load(self, n_samples: Optional[int] = None) -> None:
if self.dataset:
return
# TODO: better support w/ data url
if self.config.url.endswith(".csv"):
df = pandas.read_csv(self.config.url)
elif self.config.url.endswith(".xlsx"):
df = pandas.read_excel(self.config.url)
if n_samples is not None:
df = df.sample(n=n_samples)
self.dataset = Dataset.from_pandas(df)