eleuther custom tasks

Author: Xi Yan
Date: 2024-10-08 23:22:50 -07:00
parent b87bdd0176
commit 0919072a33
7 changed files with 290 additions and 10 deletions

@@ -6,12 +6,15 @@
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403
+import os
 import random
+from pathlib import Path
 import lm_eval
 from lm_eval.api.model import LM
 from lm_eval.evaluator import evaluate, get_task_list
 from lm_eval.tasks import get_task_dict, TaskManager
+from termcolor import cprint
 from .config import EleutherEvalsImplConfig  # noqa
@@ -20,10 +23,12 @@ class EleutherEvalsWrapper(LM):
     def __init__(
         self,
         inference_api: Inference,
+        model: str,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.inference_api = inference_api
+        self.model = model
         self.tokenizer = None
         self.tokenized_requests = False
         self.kwargs = kwargs
@@ -83,13 +88,29 @@ class EleutherEvalsWrapper(LM):
         raise NotImplementedError("No support for logits.")

     def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
-        return NotImplementedError("Not implemented")
+        res = []
+        for req in requests:
+            print("generation for msg: ", req.args[0])
+            response = self.inference_api.chat_completion(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": req.args[0],
+                    }
+                ],
+                stream=False,
+            )
+            print(response)
+            res.append(response.completion_message)
+            print(response)
+        return res


 class EleutherEvalsAdapter(Evals):
     def __init__(self, config: EleutherEvalsImplConfig, inference_api: Inference):
         self.inference_api = inference_api
-        self.eluther_wrapper = EleutherEvalsWrapper(inference_api)

     async def initialize(self) -> None:
         pass
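Note: lm_eval hands generate_until a list of Instance objects whose args[0] is the fully rendered prompt, which is why the wrapper forwards req.args[0] as the user message. A minimal illustration, assuming lm-evaluation-harness >= 0.4 (the Instance import and field names below come from that library, not from this commit):

from lm_eval.api.instance import Instance

# Build a toy request of the kind the harness would pass in.
req = Instance(
    request_type="generate_until",
    doc={},  # source document; empty here for illustration
    arguments=("What is the capital of France?", {"until": []}),  # (prompt, generation kwargs)
    idx=0,
)
print(req.args[0])  # the prompt string the wrapper sends to chat_completion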
@@ -103,16 +124,34 @@ class EleutherEvalsAdapter(Evals):
         dataset: str,
         task: str,
     ) -> EvaluateResponse:
-        task_manager = TaskManager()
+        eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model)
+        cprint(f"Eleuther Evals: {model} {dataset} {task}", "red")
+        task = "meta_mmlu_pro_instruct"
+        current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
+        print(current_dir)
+        task_manager = TaskManager(
+            include_path=str(current_dir / "tasks"),
+        )
         task_dict = get_task_dict(task, task_manager)
+        cprint(task_dict, "blue")
         task_types = set([t.task.OUTPUT_TYPE for t in get_task_list(task_dict)])
+        cprint(task_types, "cyan")
         output = evaluate(
-            self.eluther_wrapper,
+            eluther_wrapper,
             task_dict,
-            limit=2,
+            limit=1,
         )
         formatted_output = lm_eval.utils.make_table(output)
+        cprint(formatted_output, "green")
         return EvaluateResponse(
             metrics={
                 "metrics_table": formatted_output,

@@ -0,0 +1,32 @@
task: meta_ifeval
dataset_path: parquet
dataset_kwargs:
  data_files: ./work_dir/joined_ifeval.parquet
output_type: generate_until
test_split: train
num_fewshot: 0
doc_to_text: prompt
doc_to_target: 0
generation_kwargs:
  until: []
  do_sample: false
  temperature: 0.0
  max_gen_toks: 1280
process_results: !function utils.process_results
metric_list:
  - metric: prompt_level_strict_acc
    aggregation: mean
    higher_is_better: true
  - metric: inst_level_strict_acc
    aggregation: !function utils.agg_inst_level_acc
    higher_is_better: true
  - metric: prompt_level_loose_acc
    aggregation: mean
    higher_is_better: true
  - metric: inst_level_loose_acc
    aggregation: !function utils.agg_inst_level_acc
    higher_is_better: true
metadata:
  version: 2.0
fewshot_config:
  sampler: first_n
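The dataset_path/dataset_kwargs pair above is forwarded to Hugging Face datasets, so the task's data can be inspected directly. A small sketch, assuming ./work_dir/joined_ifeval.parquet has already been produced (that file is not part of this commit) and carries the columns utils.process_results expects:

import datasets

ds = datasets.load_dataset("parquet", data_files="./work_dir/joined_ifeval.parquet")
doc = ds["train"][0]               # test_split: train
print(doc["prompt"])               # doc_to_text: prompt, the column sent to the model
print(doc["instruction_id_list"])  # consumed by utils.process_results below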

@@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import dataclasses
from typing import Dict, Optional, Union

from lm_eval.tasks.ifeval import instructions_registry


@dataclasses.dataclass
class InputExample:
    key: int
    instruction_id_list: list[str]
    prompt: str
    kwargs: list[Dict[str, Optional[Union[str, int]]]]


@dataclasses.dataclass
class OutputExample:
    instruction_id_list: list[str]
    prompt: str
    response: str
    follow_all_instructions: bool
    follow_instruction_list: list[bool]


def test_instruction_following_strict(
    inp,
    response,
):
    """Tests response to see if instructions are followed."""
    instruction_list = inp.instruction_id_list
    is_following_list = []

    for index, instruction_id in enumerate(instruction_list):
        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
        instruction = instruction_cls(instruction_id)

        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
        kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
        instruction.build_description(**kwargs)
        args = instruction.get_instruction_args()
        if args and "prompt" in args:
            instruction.build_description(prompt=inp.prompt)

        if response.strip() and instruction.check_following(response):
            is_following_list.append(True)
        else:
            is_following_list.append(False)

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(is_following_list),
        follow_instruction_list=is_following_list,
    )


def test_instruction_following_loose(
    inp,
    response,
):
    """Tests response for an upper bound for following instructions."""
    r = response.split("\n")
    response_remove_first = "\n".join(r[1:]).strip()
    response_remove_last = "\n".join(r[:-1]).strip()
    response_remove_both = "\n".join(r[1:-1]).strip()
    revised_response = response.replace("*", "")
    revised_response_remove_first = response_remove_first.replace("*", "")
    revised_response_remove_last = response_remove_last.replace("*", "")
    revised_response_remove_both = response_remove_both.replace("*", "")
    all_responses = [
        response,
        revised_response,
        response_remove_first,
        response_remove_last,
        response_remove_both,
        revised_response_remove_first,
        revised_response_remove_last,
        revised_response_remove_both,
    ]
    instruction_list = inp.instruction_id_list
    is_following_list = []

    for index, instruction_id in enumerate(instruction_list):
        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
        instruction = instruction_cls(instruction_id)

        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
        kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
        instruction.build_description(**kwargs)
        args = instruction.get_instruction_args()
        if args and "prompt" in args:
            instruction.build_description(prompt=inp.prompt)

        is_following = False
        for r in all_responses:
            if r.strip() and instruction.check_following(r):
                is_following = True
                break

        is_following_list.append(is_following)

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(is_following_list),
        follow_instruction_list=is_following_list,
    )


def process_results(doc, results):
    new_kwargs = []
    for item in doc["kwargs"]:
        if item["nth_paragraph"]:
            item["nth_paragraph"] = int(item["nth_paragraph"])
        new_kwargs.append(item)
    inp = InputExample(
        key=doc["key"],
        instruction_id_list=doc["instruction_id_list"],
        prompt=doc["prompt"],
        kwargs=new_kwargs,
    )
    response = results[0]

    out_strict = test_instruction_following_strict(inp, response)
    out_loose = test_instruction_following_loose(inp, response)

    return {
        "prompt_level_strict_acc": out_strict.follow_all_instructions,
        "inst_level_strict_acc": out_strict.follow_instruction_list,
        "prompt_level_loose_acc": out_loose.follow_all_instructions,
        "inst_level_loose_acc": out_loose.follow_instruction_list,
    }


def agg_inst_level_acc(items):
    flat_items = [item for sublist in items for item in sublist]
    inst_level_acc = sum(flat_items) / len(flat_items)
    return inst_level_acc
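A quick illustration of how the two metric levels differ: process_results returns one all-or-nothing boolean per prompt for the prompt_level_* metrics and a list of per-instruction booleans for the inst_level_* metrics, which agg_inst_level_acc flattens before averaging. Toy values, assuming the helpers above are importable:

# Two prompts: the first followed 2 of its 3 instructions, the second followed its only one.
per_prompt_instruction_results = [[True, False, True], [True]]
print(agg_inst_level_acc(per_prompt_instruction_results))  # 3 / 4 = 0.75
# Prompt-level accuracy instead averages the all-or-nothing values [False, True] -> 0.5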

@@ -0,0 +1,29 @@
task: meta_mmlu_pro_instruct
dataset_path: meta-llama/Llama-3.1-8B-Instruct-evals
dataset_name: Llama-3.1-8B-Instruct-evals__mmlu_pro__details
test_split: latest
output_type: generate_until
process_docs: !function utils.process_docs
doc_to_text: !function utils.doc_to_text
doc_to_target: gold
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        group_select: -1
        regex_pattern: 'best answer is ([A-Z])'
      - function: "take_first"
generation_kwargs:
  until: []
  do_sample: false
  temperature: 0
  max_gen_toks: 1024
num_fewshot: 0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
metadata:
  version: 1.0
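For intuition, a rough sketch of what the strict-match filter chain does to a generated answer: the regex is applied findall-style, group_select: -1 keeps the last match, and take_first then keeps the first filtered answer per document. The response text below is made up:

import re

response = "The best answer is C. On reflection, the best answer is B."
matches = re.findall(r"best answer is ([A-Z])", response)
prediction = matches[-1] if matches else ""  # group_select: -1, the final occurrence wins
print(prediction)  # B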

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import datasets


def doc_to_text(doc: dict) -> str:
    return doc["input_final_prompts"][0]


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _process_doc(doc: dict) -> dict:
        out_doc = {
            "problem": doc["input_question"],
            "gold": doc["input_correct_responses"][0],
        }
        return out_doc

    dataset = dataset.select_columns(
        [
            "input_question",
            "input_correct_responses",
            "input_final_prompts",
            "is_correct",
            "input_question_hash",
            "input_choice_list",
            "output_prediction_text",
        ],
    )
    dataset = dataset.rename_column("is_correct", "previously_is_correct")
    dataset = dataset.map(_process_doc)
    return dataset.map(_process_doc)
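A toy run of the preprocessing above, assuming process_docs and doc_to_text are importable and using made-up values that mimic the Llama-3.1-8B-Instruct-evals column layout:

import datasets

toy = datasets.Dataset.from_list(
    [
        {
            "input_question": "What is 2 + 2?",
            "input_correct_responses": ["C"],
            "input_final_prompts": ["What is 2 + 2? A. 3  B. 5  C. 4  D. 22. The best answer is"],
            "is_correct": True,
            "input_question_hash": "deadbeef",
            "input_choice_list": ["3", "5", "4", "22"],
            "output_prediction_text": ["The best answer is C."],
        }
    ]
)

processed = process_docs(toy)
print(doc_to_text(processed[0]))  # the first final prompt, passed to the model verbatim
print(processed[0]["gold"])       # "C", compared against the filtered prediction via exact_match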

@@ -152,7 +152,7 @@ def severity(levelname: str) -> LogSeverity:
     elif levelname == "INFO":
         return LogSeverity.INFO
     elif levelname == "WARNING":
-        return LogSeverity.WARNING
+        return LogSeverity.WARN
     elif levelname == "ERROR":
         return LogSeverity.ERROR
    elif levelname == "CRITICAL":

@@ -12,12 +12,12 @@ apis_to_serve:
 - safety
 - evals
 api_providers:
-  # evals:
-  #   provider_type: eleuther
-  #   config: {}
   evals:
-    provider_type: meta-reference
+    provider_type: eleuther
     config: {}
+  # evals:
+  #   provider_type: meta-reference
+  #   config: {}
   inference:
     providers:
     - meta-reference