This commit is contained in:
Xi Yan 2024-10-03 17:31:46 -07:00
parent 8339b2cef3
commit 4f07aca309
6 changed files with 188 additions and 1 deletions

View file

@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
class BaseTask(ABC):
"""
Base class for all evaluation tasks.
Each task needs to implement the following methods:
- F1: preprocess_sample(self)
- F2: postprocess_sample(self)
- F3: score_sample(self)
"""
def __init__(self, dataset, *args, **kwargs):
super().__init__(*args, **kwargs)
self._name = self.__class__.__name__
self.dataset = dataset
@abstractmethod
def preprocess_sample(self, sample):
"""
F1: preprocess sample
"""
raise NotImplementedError()
@abstractmethod
def postprocess_sample(self, sample):
"""
F2: postprocess sample
"""
raise NotImplementedError()
@abstractmethod
def score_sample(self, sample, ground_truth):
"""
F3: score sample
"""
raise NotImplementedError()
def preprocess(self):
pass
def postprocess(self):
pass
def score(self, generation):
pass
class MMLUTask(BaseTask):
"""
MMLU Task. Each task needs to implement the following methods:
- F1: preprocess_sample(self)
- F2: postprocess_sample(self)
- F3: score_sample(self)
"""
def __init__(self, dataset, *args, **kwargs):
super().__init__(dataset, *args, **kwargs)
def preprocess_sample(self, sample):
"""
F1: preprocess sample
"""
pass
def postprocess_sample(self, sample):
"""
F2: postprocess sample
"""
pass
def score_sample(self, sample):
"""
F3: score sample
"""
pass

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .tasks import * # noqa: F403
# TODO: make this into a config based registry
TASKS_REGISTRY = {
"mmlu": MMLUTask,
}
def get_task(task_id: str, dataset):
task_impl = TASKS_REGISTRY[task_id]
return task_impl(dataset)