wip add datatypes

This commit is contained in:
Xi Yan 2024-10-10 19:56:19 -07:00
parent 99ed1425fc
commit 9816c9aae6
5 changed files with 175 additions and 57 deletions

View file

@ -111,7 +111,14 @@ def normalize_extracted_answer(extracted_answer: str) -> str:
)
class MMLUTask(BaseTask):
class MMLUTask(
BaseTask[
DictSample,
InferencePreprocessedSample,
InferencePredictionSample,
InferencePostprocessedSample,
]
):
"""
MMLU Task.
"""
@ -120,6 +127,7 @@ class MMLUTask(BaseTask):
super().__init__(dataset, *args, **kwargs)
def preprocess_sample(self, sample):
print(sample)
content = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
return {
"role": "user",