diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py b/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py
index ee51adf35..c22ddfef9 100644
--- a/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py
+++ b/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py
@@ -6,12 +6,15 @@
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403

+import os
 import random
+from pathlib import Path

 import lm_eval
 from lm_eval.api.model import LM
 from lm_eval.evaluator import evaluate, get_task_list
 from lm_eval.tasks import get_task_dict, TaskManager
+from termcolor import cprint

 from .config import EleutherEvalsImplConfig  # noqa

@@ -20,10 +23,12 @@ class EleutherEvalsWrapper(LM):
     def __init__(
         self,
         inference_api: Inference,
+        model: str,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.inference_api = inference_api
+        self.model = model
         self.tokenizer = None
         self.tokenized_requests = False
         self.kwargs = kwargs
@@ -83,13 +88,29 @@ class EleutherEvalsWrapper(LM):
         raise NotImplementedError("No support for logits.")

     def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
-        return NotImplementedError("Not implemented")
+        res = []
+        for req in requests:
+            print("generation for msg: ", req.args[0])
+            response = self.inference_api.chat_completion(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": req.args[0],
+                    }
+                ],
+                stream=False,
+            )
+            print(response)
+            res.append(response.completion_message)
+
+        print(response)
+        return res


 class EleutherEvalsAdapter(Evals):
     def __init__(self, config: EleutherEvalsImplConfig, inference_api: Inference):
         self.inference_api = inference_api
-        self.eluther_wrapper = EleutherEvalsWrapper(inference_api)

     async def initialize(self) -> None:
         pass
@@ -103,16 +124,34 @@ class EleutherEvalsAdapter(Evals):
         dataset: str,
         task: str,
     ) -> EvaluateResponse:
-        task_manager = TaskManager()
+        eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model)
+
+        cprint(f"Eleuther Evals: {model} {dataset} {task}", "red")
+
+        task = "meta_mmlu_pro_instruct"
+        current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
+        print(current_dir)
+
+        task_manager = TaskManager(
+            include_path=str(current_dir / "tasks"),
+        )
+
         task_dict = get_task_dict(task, task_manager)
+        cprint(task_dict, "blue")
+
         task_types = set([t.task.OUTPUT_TYPE for t in get_task_list(task_dict)])
+        cprint(task_types, "cyan")

         output = evaluate(
-            self.eluther_wrapper,
+            eluther_wrapper,
             task_dict,
-            limit=2,
+            limit=1,
         )
+
         formatted_output = lm_eval.utils.make_table(output)
+
+        cprint(formatted_output, "green")
+
         return EvaluateResponse(
             metrics={
                 "metrics_table": formatted_output,
diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/tasks/ifeval/ifeval.yaml b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/ifeval/ifeval.yaml
new file mode 100644
index 000000000..c7196d16d
--- /dev/null
+++ b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/ifeval/ifeval.yaml
@@ -0,0 +1,32 @@
+task: meta_ifeval
+dataset_path: parquet
+dataset_kwargs:
+  data_files: ./work_dir/joined_ifeval.parquet
+output_type: generate_until
+test_split: train
+num_fewshot: 0
+doc_to_text: prompt
+doc_to_target: 0
+generation_kwargs:
+  until: []
+  do_sample: false
+  temperature: 0.0
+  max_gen_toks: 1280
+process_results: !function utils.process_results
+metric_list:
+  - metric: prompt_level_strict_acc
+    aggregation: mean
+    higher_is_better: true
+  - metric: inst_level_strict_acc
+    aggregation: !function utils.agg_inst_level_acc
+    higher_is_better: true
+  - metric: prompt_level_loose_acc
+    aggregation: mean
+    higher_is_better: true
+  - metric: inst_level_loose_acc
+    aggregation: !function utils.agg_inst_level_acc
+    higher_is_better: true
+metadata:
+  version: 2.0
+fewshot_config:
+  sampler: first_n
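Note on the wrapper above: the harness only needs the lm_eval LM interface to be satisfied. The sketch below is illustrative and not part of the provider; EchoLM is a hypothetical stand-in that returns canned strings (lm_eval expects one plain string per request), and the task name mirrors the hard-coded meta_mmlu_pro_instruct used in run_evals.

```python
from pathlib import Path

import lm_eval
from lm_eval.api.model import LM
from lm_eval.evaluator import evaluate
from lm_eval.tasks import TaskManager, get_task_dict


class EchoLM(LM):
    """Hypothetical stand-in for EleutherEvalsWrapper; returns canned text."""

    def loglikelihood(self, requests):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError("No support for logits.")

    def generate_until(self, requests, disable_tqdm: bool = False):
        # req.args[0] is the rendered prompt; the harness expects one string per request.
        return ["The best answer is A." for _ in requests]


if __name__ == "__main__":
    tasks_dir = Path(__file__).parent / "tasks"
    task_manager = TaskManager(include_path=str(tasks_dir))
    task_dict = get_task_dict("meta_mmlu_pro_instruct", task_manager)
    output = evaluate(EchoLM(), task_dict, limit=1)
    print(lm_eval.utils.make_table(output))
```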
diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/tasks/ifeval/utils.py b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/ifeval/utils.py
new file mode 100644
index 000000000..5c7c92494
--- /dev/null
+++ b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/ifeval/utils.py
@@ -0,0 +1,145 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import dataclasses
+from typing import Dict, Optional, Union
+
+from lm_eval.tasks.ifeval import instructions_registry
+
+
+@dataclasses.dataclass
+class InputExample:
+    key: int
+    instruction_id_list: list[str]
+    prompt: str
+    kwargs: list[Dict[str, Optional[Union[str, int]]]]
+
+
+@dataclasses.dataclass
+class OutputExample:
+    instruction_id_list: list[str]
+    prompt: str
+    response: str
+    follow_all_instructions: bool
+    follow_instruction_list: list[bool]
+
+
+def test_instruction_following_strict(
+    inp,
+    response,
+):
+    """Tests response to see if instructions are followed."""
+    instruction_list = inp.instruction_id_list
+    is_following_list = []
+
+    for index, instruction_id in enumerate(instruction_list):
+        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
+        instruction = instruction_cls(instruction_id)
+
+        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
+        kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
+        instruction.build_description(**kwargs)
+        args = instruction.get_instruction_args()
+        if args and "prompt" in args:
+            instruction.build_description(prompt=inp.prompt)
+
+        if response.strip() and instruction.check_following(response):
+            is_following_list.append(True)
+        else:
+            is_following_list.append(False)
+
+    return OutputExample(
+        instruction_id_list=inp.instruction_id_list,
+        prompt=inp.prompt,
+        response=response,
+        follow_all_instructions=all(is_following_list),
+        follow_instruction_list=is_following_list,
+    )
+
+
+def test_instruction_following_loose(
+    inp,
+    response,
+):
+    """Tests response for an upper bound for following instructions."""
+    r = response.split("\n")
+    response_remove_first = "\n".join(r[1:]).strip()
+    response_remove_last = "\n".join(r[:-1]).strip()
+    response_remove_both = "\n".join(r[1:-1]).strip()
+    revised_response = response.replace("*", "")
+    revised_response_remove_first = response_remove_first.replace("*", "")
+    revised_response_remove_last = response_remove_last.replace("*", "")
+    revised_response_remove_both = response_remove_both.replace("*", "")
+    all_responses = [
+        response,
+        revised_response,
+        response_remove_first,
+        response_remove_last,
+        response_remove_both,
+        revised_response_remove_first,
+        revised_response_remove_last,
+        revised_response_remove_both,
+    ]
+    instruction_list = inp.instruction_id_list
+    is_following_list = []
+
+    for index, instruction_id in enumerate(instruction_list):
+        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
+        instruction = instruction_cls(instruction_id)
+
+        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
+        kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
+        instruction.build_description(**kwargs)
+        args = instruction.get_instruction_args()
+        if args and "prompt" in args:
+            instruction.build_description(prompt=inp.prompt)
+
+        is_following = False
+        for r in all_responses:
+            if r.strip() and instruction.check_following(r):
+                is_following = True
+                break
+
+        is_following_list.append(is_following)
+
+    return OutputExample(
+        instruction_id_list=inp.instruction_id_list,
+        prompt=inp.prompt,
+        response=response,
+        follow_all_instructions=all(is_following_list),
+        follow_instruction_list=is_following_list,
+    )
+
+
+def process_results(doc, results):
+    new_kwargs = []
+    for item in doc["kwargs"]:
+        if item["nth_paragraph"]:
+            item["nth_paragraph"] = int(item["nth_paragraph"])
+        new_kwargs.append(item)
+    inp = InputExample(
+        key=doc["key"],
+        instruction_id_list=doc["instruction_id_list"],
+        prompt=doc["prompt"],
+        kwargs=new_kwargs,
+    )
+    response = results[0]
+
+    out_strict = test_instruction_following_strict(inp, response)
+    out_loose = test_instruction_following_loose(inp, response)
+
+    return {
+        "prompt_level_strict_acc": out_strict.follow_all_instructions,
+        "inst_level_strict_acc": out_strict.follow_instruction_list,
+        "prompt_level_loose_acc": out_loose.follow_all_instructions,
+        "inst_level_loose_acc": out_loose.follow_instruction_list,
+    }
+
+
+def agg_inst_level_acc(items):
+    flat_items = [item for sublist in items for item in sublist]
+    inst_level_acc = sum(flat_items) / len(flat_items)
+    return inst_level_acc
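A quick illustration (not part of the patch) of how the two aggregation modes declared in ifeval.yaml treat process_results output: the prompt-level metrics are single booleans averaged with mean, while the instruction-level metrics are per-document lists that agg_inst_level_acc flattens before averaging. The booleans below are made-up stand-ins for three documents.

```python
# Toy per-document results standing in for process_results() output.
prompt_level = [True, False, True]                  # one bool per document
inst_level = [[True, True], [False, True], [True]]  # one list of bools per document


def agg_inst_level_acc(items):
    # Same flattening as utils.agg_inst_level_acc above.
    flat_items = [item for sublist in items for item in sublist]
    return sum(flat_items) / len(flat_items)


print(sum(prompt_level) / len(prompt_level))  # prompt_level_*_acc -> 2/3
print(agg_inst_level_acc(inst_level))         # inst_level_*_acc   -> 4/5 = 0.8
```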
diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/tasks/mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml
new file mode 100644
index 000000000..1ec3c107d
--- /dev/null
+++ b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml
@@ -0,0 +1,29 @@
+task: meta_mmlu_pro_instruct
+dataset_path: meta-llama/Llama-3.1-8B-Instruct-evals
+dataset_name: Llama-3.1-8B-Instruct-evals__mmlu_pro__details
+test_split: latest
+output_type: generate_until
+process_docs: !function utils.process_docs
+doc_to_text: !function utils.doc_to_text
+doc_to_target: gold
+filter_list:
+  - name: "strict-match"
+    filter:
+      - function: "regex"
+        group_select: -1
+        regex_pattern: 'best answer is ([A-Z])'
+      - function: "take_first"
+generation_kwargs:
+  until: []
+  do_sample: false
+  temperature: 0
+  max_gen_toks: 1024
+num_fewshot: 0
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+metadata:
+  version: 1.0
diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/tasks/mmlu_pro/utils.py b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/mmlu_pro/utils.py
new file mode 100644
index 000000000..e25717e98
--- /dev/null
+++ b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/mmlu_pro/utils.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import datasets
+
+
+def doc_to_text(doc: dict) -> str:
+    return doc["input_final_prompts"][0]
+
+
+def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
+    def _process_doc(doc: dict) -> dict:
+        out_doc = {
+            "problem": doc["input_question"],
+            "gold": doc["input_correct_responses"][0],
+        }
+        return out_doc
+
+    dataset = dataset.select_columns(
+        [
+            "input_question",
+            "input_correct_responses",
+            "input_final_prompts",
+            "is_correct",
+            "input_question_hash",
+            "input_choice_list",
+            "output_prediction_text",
+        ],
+    )
+    dataset = dataset.rename_column("is_correct", "previously_is_correct")
+    dataset = dataset.map(_process_doc)
+    return dataset.map(_process_doc)
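For clarity, the "strict-match" filter in the mmlu_pro YAML above pulls the final answer letter out of a chain-of-thought generation, and group_select: -1 keeps the last occurrence. The snippet below approximates that behaviour with plain re; the generation text is invented and the filter internals are only sketched, not lifted from lm_eval.

```python
import re

# Made-up CoT generation containing two candidate answer statements.
generation = (
    "Let's think step by step. Option B looks plausible, so the best answer is B.\n"
    "On reflection, option C fits every constraint, so the best answer is C."
)

matches = re.findall(r"best answer is ([A-Z])", generation)
# group_select: -1 keeps the last capture; take_first then reduces to a single value.
prediction = matches[-1] if matches else ""
print(prediction)  # -> "C", compared against the "gold" target via exact_match
```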
diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py
index 9fffc0f99..207064904 100644
--- a/llama_stack/providers/utils/telemetry/tracing.py
+++ b/llama_stack/providers/utils/telemetry/tracing.py
@@ -152,7 +152,7 @@ def severity(levelname: str) -> LogSeverity:
     elif levelname == "INFO":
         return LogSeverity.INFO
     elif levelname == "WARNING":
-        return LogSeverity.WARNING
+        return LogSeverity.WARN
     elif levelname == "ERROR":
         return LogSeverity.ERROR
     elif levelname == "CRITICAL":
diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml
index 78b1b32d7..b1caab8fc 100644
--- a/tests/examples/local-run.yaml
+++ b/tests/examples/local-run.yaml
@@ -12,12 +12,12 @@ apis_to_serve:
 - safety
 - evals
 api_providers:
-  # evals:
-  #   provider_type: eleuther
-  #   config: {}
   evals:
-    provider_type: meta-reference
+    provider_type: eleuther
     config: {}
+  # evals:
+  #   provider_type: meta-reference
+  #   config: {}
   inference:
     providers:
     - meta-reference
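The tracing.py hunk above works because the telemetry LogSeverity enum exposes WARN rather than WARNING, so the name coming from Python's logging module has to be translated. Below is a self-contained, hypothetical version of that mapping: the enum values and the compact dict-based lookup are invented for illustration; only the WARNING-to-WARN rename reflects the actual change.

```python
from enum import Enum


class LogSeverity(Enum):
    # Hypothetical members; only the WARN (not WARNING) spelling is taken from the patch.
    DEBUG = "debug"
    INFO = "info"
    WARN = "warn"
    ERROR = "error"
    CRITICAL = "critical"


def severity(levelname: str) -> LogSeverity:
    # Python's logging module reports "WARNING", so translate that one name.
    renames = {"WARNING": "WARN"}
    return LogSeverity[renames.get(levelname, levelname)]


assert severity("WARNING") is LogSeverity.WARN
assert severity("ERROR") is LogSeverity.ERROR
```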