This commit is contained in:
Xi Yan 2024-10-04 00:25:57 -07:00
parent 4f07aca309
commit 3cbe3a72e8
10 changed files with 230 additions and 76 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from .task import BaseTask
# Prompt for a four-option multiple-choice question. The {Question}, {A},
# {B}, {C} and {D} placeholders are filled in with str.format(); the model is
# asked to finish with a line of the form "Answer: <LETTER>".
QUERY_TEMPLATE_MULTICHOICE = "\n".join(
    (
        "Answer the following multiple choice question and make the answer very simple. "
        "The last line of your response should be of the following format: "
        "'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.",
        "{Question}",
        "A) {A}",
        "B) {B}",
        "C) {C}",
        "D) {D}",
    )
)
# Regex fragments matching the word for "Answer" (plus its colon) in the
# supported languages. Each fragment is substituted into
# MULTILINGUAL_ANSWER_PATTERN_TEMPLATE and tried in list order, so every entry
# must contain a real answer word: an empty or colon-only prefix would match
# almost any text and shadow all entries after it.
MULTILINGUAL_ANSWER_REGEXES = [
    r"Answer\s*:",
    # "Answer:" followed by invisible zero-width characters (seen in Korean
    # output). NOTE(review): written as explicit escapes so the characters
    # cannot be lost by editors; confirm the set against real model output.
    r"Answer\s*:[\u200b\u200c\u200d\u2060\ufeff]*",
    r"উত্তর\s*:",
    r"उत्तर\s*:",
    r"উত্তরঃ",
    r"উত্তর\s*:",
    r"Antwort\s*:",
    r"답변\s*:",
    r"정답\s*:",
    r"답\s*:",
    # Chinese: both the full-width colon (U+FF1A) and the ASCII colon occur.
    r"答案\s*：",
    r"答案\s*:",
    r"答\s*：",
    r"答\s*:",
    r"答复\s*：",
    r"答曰\s*：",
    r"الإجابة:",
    r"الجواب:",
    r"إجابة:",
    r"الإجابة النهائية:",
    r"الإجابة الصحيحة:",
    r"الإجابة الصحيحة هي:",
    r"الإجابة هي:",
    r"Respuesta\s*:",
    r"Risposta\s*:",
    # Japanese: ASCII colon and full-width colon variants.
    r"答え\s*:",
    r"答え\s*：",
    r"回答\s*:",
    r"回答\s*：",
    r"解答\s*:",
    r"Jawaban\s*:",
    r"Réponse\s*:",
    r"Resposta\s*:",
    r"Jibu\s*:",
    r"Idahun\s*:",
    r"Ìdáhùn\s*:",
    r"Idáhùn\s*:",
    r"Àmọ̀nà\s*:",
    r"Àdáhùn\s*:",
    r"Ànúgọ\s*:",
    r"Àṣàyàn\s*:",
]
# Full answer pattern: "{}" receives one answer-prefix regex; the single
# capture group is the answer letter. Accepted letters are ASCII A-D, the
# Arabic range, the Bengali option letters, and the full-width forms U+FF21
# to U+FF24. The full-width alternatives must not be empty: "[]|[]" parses
# as a character class matching "]", "|" and "[" instead.
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[Ａ]|[Ｂ]|[Ｃ]|[Ｄ])"
)
def normalize_response(response: str) -> str:
    """Strip Markdown and LaTeX markup that could hide the answer letter.

    Every occurrence of each markup token is deleted from ``response``.

    Args:
        response: Raw model output.

    Returns:
        The text with all listed tokens removed.
    """
    # Order is significant: longer tokens such as "$\\boxed{" have to be
    # removed before the shorter tokens they contain ("$", "{", "\\boxed").
    markup_tokens = (
        "**",
        "$\\boxed{",
        "}$",
        "\\$",
        "$\\text{",
        "$",
        "\\mathrm{",
        "\\{",
        "\\text",
        "\\(",
        "\\mathbf{",
        "{",
        "\\boxed",
    )
    cleaned = response
    for token in markup_tokens:
        cleaned = cleaned.replace(token, "")
    return cleaned
# Maps non-Latin option letters to their ASCII A-D equivalent. Each
# replacement carries a leading space, which the final strip() removes when
# the letter is at the start of the extracted text.
_ANSWER_LETTER_TABLE = str.maketrans(
    {
        # In Arabic these are the letters used for A-D in multiple choice questions
        "أ": " A",
        "ب": " B",
        "ج": " C",
        "د": " D",
        # In Bengali these are the letters used for A-D in multiple choice questions
        "অ": " A",
        "ব": " B",
        "ড": " C",
        "ঢ": " D",
        # In Japanese these full-width letters are sometimes used for A-D
        "Ａ": " A",
        "Ｂ": " B",
        "Ｃ": " C",
        "Ｄ": " D",
    }
)


def normalize_extracted_answer(extracted_answer: str) -> str:
    """Convert an extracted answer letter to its ASCII A-D form.

    Arabic, Bengali and full-width option letters are mapped to "A"-"D";
    any other text is returned unchanged apart from whitespace stripping.
    The search strings must never be empty: ``str.replace("", x)`` inserts
    ``x`` between every character and would corrupt plain ASCII answers.

    Args:
        extracted_answer: Text captured by the answer regex.

    Returns:
        The normalized answer with surrounding whitespace removed.
    """
    return extracted_answer.translate(_ANSWER_LETTER_TABLE).strip()
class MMLUTask(BaseTask):
    """Multiple-choice evaluation task built on BaseTask.

    Supplies the three per-sample hooks: building the prompt, cleaning the
    model output, and scoring the cleaned output against the expected answer.
    """

    def __init__(self, dataset, *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)

    def preprocess_sample(self, sample):
        """Build the user chat message for one dataset row.

        ``sample`` is unpacked into QUERY_TEMPLATE_MULTICHOICE, so it must
        provide the keys ``Question``, ``A``, ``B``, ``C`` and ``D``.
        """
        prompt = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
        return {"role": "user", "content": prompt}

    def postprocess_sample(self, sample):
        """Return the model output with Markdown/LaTeX markup removed."""
        return normalize_response(sample)

    def score_sample(self, sample, expected):
        """Score one post-processed output as 1.0 (correct) or 0.0.

        The answer prefixes are tried in list order and the first one that
        matches decides the extracted answer. That answer is normalized and
        compared for equality with ``expected["Answer"]``.
        """
        # Generators keep the search lazy: it stops at the first matching prefix.
        candidate_patterns = (
            MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(prefix)
            for prefix in MULTILINGUAL_ANSWER_REGEXES
        )
        attempts = (re.search(pattern, sample) for pattern in candidate_patterns)
        first_hit = next((hit for hit in attempts if hit), None)
        if first_hit is None:
            return 0.0

        answer = normalize_extracted_answer(first_hit.group(1))
        if not answer:
            return 0.0
        return 1.0 if answer == expected["Answer"] else 0.0

View file

@ -8,8 +8,7 @@ from abc import ABC, abstractmethod
class BaseTask(ABC):
"""
Base class for all evaluation tasks.
Each task needs to implement the following methods:
Base class for all evaluation tasks. Each task needs to implement the following methods:
- F1: preprocess_sample(self)
- F2: postprocess_sample(self)
- F3: score_sample(self)
@ -42,40 +41,13 @@ class BaseTask(ABC):
raise NotImplementedError()
def preprocess(self):
pass
return [self.preprocess_sample(sample) for sample in self.dataset]
def postprocess(self):
pass
def postprocess(self, generation):
return [self.postprocess_sample(sample) for sample in generation]
def score(self, generation):
pass
class MMLUTask(BaseTask):
"""
MMLU Task. Each task needs to implement the following methods:
- F1: preprocess_sample(self)
- F2: postprocess_sample(self)
- F3: score_sample(self)
"""
def __init__(self, dataset, *args, **kwargs):
super().__init__(dataset, *args, **kwargs)
def preprocess_sample(self, sample):
"""
F1: preprocess sample
"""
pass
def postprocess_sample(self, sample):
"""
F2: postprocess sample
"""
pass
def score_sample(self, sample):
"""
F3: score sample
"""
pass
def score(self, postprocessed):
return [
self.score_sample(sample, ground_truth)
for sample, ground_truth in zip(postprocessed, self.dataset)
]

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .tasks import * # noqa: F403
from .mmlu_task import MMLUTask
# TODO: make this into a config based registry
TASKS_REGISTRY = {