diff --git a/llama_stack/apis/dataset/dataset.py b/llama_stack/apis/dataset/dataset.py index 2fa8bb4e5..ba2cb8811 100644 --- a/llama_stack/apis/dataset/dataset.py +++ b/llama_stack/apis/dataset/dataset.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum +# from enum import Enum from typing import Any, Dict, Optional, Protocol from llama_models.llama3.api.datatypes import URL @@ -14,22 +14,12 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -@json_schema_type -class TrainEvalDatasetColumnType(Enum): - dialog = "dialog" - text = "text" - media = "media" - number = "number" - json = "json" - - @json_schema_type class TrainEvalDataset(BaseModel): """Dataset to be used for training or evaluating language models.""" - # TODO(ashwin): figure out if we need to add an enum for a "dataset type" - - columns: Dict[str, TrainEvalDatasetColumnType] + # unique identifier associated with the dataset + dataset_id: str content_url: URL metadata: Optional[Dict[str, Any]] = None diff --git a/llama_stack/apis/evals/client.py b/llama_stack/apis/evals/client.py new file mode 100644 index 000000000..ad4a47145 --- /dev/null +++ b/llama_stack/apis/evals/client.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import json + +import fire +import httpx +from termcolor import cprint + +from .evals import * # noqa: F403 + + +class EvaluationClient(Evals): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def run_evals( + self, + model: str, + task: str, + dataset: Optional[str] = None, + eval_task_config: Optional[EvaluateTaskConfig] = None, + ) -> EvaluateResponse: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/evals/run", + json={ + "model": model, + "task": task, + "dataset": dataset, + "eval_task_config": ( + json.loads(eval_task_config.json()) + if eval_task_config + else None + ), + }, + headers={"Content-Type": "application/json"}, + timeout=3600, + ) + response.raise_for_status() + return EvaluateResponse(**response.json()) + + +async def run_main(host: str, port: int): + client = EvaluationClient(f"http://{host}:{port}") + + # CustomDataset + response = await client.run_evals( + model="Llama3.1-8B-Instruct", + dataset="mmlu-simple-eval-en", + task="mmlu", + eval_task_config=EvaluateTaskConfig( + n_samples=2, + ), + ) + cprint(f"evaluate response={response}", "green") + + # Eleuther Eval Task + # response = await client.run_evals( + # model="Llama3.1-8B-Instruct", + # task="meta_mmlu_pro_instruct", + # # task="meta_ifeval", + # eval_task_config=EvaluateTaskConfig( + # n_samples=2, + # ) + # ) + # cprint(response.metrics["metrics_table"], "red") + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/evals/evals.py b/llama_stack/apis/evals/evals.py index 0be2243ab..dbb1348a5 100644 --- a/llama_stack/apis/evals/evals.py +++ b/llama_stack/apis/evals/evals.py @@ -4,8 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
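The new client above is a thin wrapper around a single POST to /evals/run. For reference, the equivalent raw request is sketched below; it assumes a stack listening on localhost:5000 and sends only the n_samples field of the task config, relying on defaults for the rest.

    import httpx

    # same payload EvaluationClient.run_evals builds before POSTing
    response = httpx.post(
        "http://localhost:5000/evals/run",
        json={
            "model": "Llama3.1-8B-Instruct",
            "task": "mmlu",
            "dataset": "mmlu-simple-eval-en",
            "eval_task_config": {"n_samples": 2},
        },
        headers={"Content-Type": "application/json"},
        timeout=3600,
    )
    response.raise_for_status()
    print(response.json())  # {"metrics": {"score": "..."}}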
-from enum import Enum -from typing import List, Protocol +from typing import Protocol from llama_models.schema_utils import webmethod @@ -13,23 +12,6 @@ from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.dataset import * # noqa: F403 -from llama_stack.apis.common.training_types import * # noqa: F403 - - -class TextGenerationMetric(Enum): - perplexity = "perplexity" - rouge = "rouge" - bleu = "bleu" - - -class QuestionAnsweringMetric(Enum): - em = "em" - f1 = "f1" - - -class SummarizationMetric(Enum): - rouge = "rouge" - bleu = "bleu" class EvaluationJob(BaseModel): @@ -40,37 +22,21 @@ class EvaluationJobLogStream(BaseModel): job_uuid: str -class EvaluateTaskRequestCommon(BaseModel): - job_uuid: str - dataset: TrainEvalDataset - - checkpoint: Checkpoint - - # generation params +class EvaluateTaskConfig(BaseModel): + # num examples to evaluate, evaluate all if None + n_samples: Optional[int] = None + # model evaluation params sampling_params: SamplingParams = SamplingParams() @json_schema_type -class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon): - """Request to evaluate text generation.""" +class EvaluateResponse(BaseModel): + """Scores for evaluation.""" - metrics: List[TextGenerationMetric] + metrics: Dict[str, str] @json_schema_type -class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon): - """Request to evaluate question answering.""" - - metrics: List[QuestionAnsweringMetric] - - -@json_schema_type -class EvaluateSummarizationRequest(EvaluateTaskRequestCommon): - """Request to evaluate summarization.""" - - metrics: List[SummarizationMetric] - - class EvaluationJobStatusResponse(BaseModel): job_uuid: str @@ -82,41 +48,44 @@ class EvaluationJobArtifactsResponse(BaseModel): job_uuid: str -class Evaluations(Protocol): - @webmethod(route="/evaluate/text_generation/") - def evaluate_text_generation( +@json_schema_type +class EvaluationJobCreateResponse(BaseModel): + """Response to create a evaluation job.""" + + job_uuid: str + + +class Evals(Protocol): + @webmethod(route="/evals/run") + async def run_evals( self, - metrics: List[TextGenerationMetric], - ) -> EvaluationJob: ... + model: str, + task: str, + dataset: Optional[str] = None, + eval_task_config: Optional[EvaluateTaskConfig] = None, + ) -> EvaluateResponse: ... - @webmethod(route="/evaluate/question_answering/") - def evaluate_question_answering( - self, - metrics: List[QuestionAnsweringMetric], - ) -> EvaluationJob: ... + # @webmethod(route="/evals/jobs") + # def get_evaluation_jobs(self) -> List[EvaluationJob]: ... - @webmethod(route="/evaluate/summarization/") - def evaluate_summarization( - self, - metrics: List[SummarizationMetric], - ) -> EvaluationJob: ... + # @webmethod(route="/evals/job/create") + # async def create_evaluation_job( + # self, model: str, dataset: str, task: str + # ) -> EvaluationJob: ... - @webmethod(route="/evaluate/jobs") - def get_evaluation_jobs(self) -> List[EvaluationJob]: ... + # @webmethod(route="/evals/job/status") + # def get_evaluation_job_status( + # self, job_uuid: str + # ) -> EvaluationJobStatusResponse: ... - @webmethod(route="/evaluate/job/status") - def get_evaluation_job_status( - self, job_uuid: str - ) -> EvaluationJobStatusResponse: ... + # # sends SSE stream of logs + # @webmethod(route="/evals/job/logs") + # def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ... 
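The trimmed-down surface is a single run_evals method returning a flat string-to-string metrics map; the job-management routes are parked as comments for now. A toy provider satisfying the new protocol is sketched below, purely for illustration; the real implementations are added later in this diff.

    from typing import Optional

    from llama_stack.apis.evals.evals import Evals, EvaluateResponse, EvaluateTaskConfig

    class EchoEvals(Evals):
        """Illustrative stub: returns a fixed score without running anything."""

        async def run_evals(
            self,
            model: str,
            task: str,
            dataset: Optional[str] = None,
            eval_task_config: Optional[EvaluateTaskConfig] = None,
        ) -> EvaluateResponse:
            # a real provider generates completions and scores them here
            return EvaluateResponse(metrics={"score": "0.0"})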
- # sends SSE stream of logs - @webmethod(route="/evaluate/job/logs") - def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ... + # @webmethod(route="/evals/job/cancel") + # def cancel_evaluation_job(self, job_uuid: str) -> None: ... - @webmethod(route="/evaluate/job/cancel") - def cancel_evaluation_job(self, job_uuid: str) -> None: ... - - @webmethod(route="/evaluate/job/artifacts") - def get_evaluation_job_artifacts( - self, job_uuid: str - ) -> EvaluationJobArtifactsResponse: ... + # @webmethod(route="/evals/job/artifacts") + # def get_evaluation_job_artifacts( + # self, job_uuid: str + # ) -> EvaluationJobArtifactsResponse: ... diff --git a/llama_stack/distribution/registry/__init__.py b/llama_stack/distribution/registry/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/registry/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/registry/datasets/__init__.py b/llama_stack/distribution/registry/datasets/__init__.py new file mode 100644 index 000000000..0b7a84395 --- /dev/null +++ b/llama_stack/distribution/registry/datasets/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# TODO: make these import config based +from .dataset import CustomDataset, HFDataset +from .dataset_registry import DatasetRegistry + +DATASETS_REGISTRY = { + "mmlu-simple-eval-en": CustomDataset( + name="mmlu_eval", + url="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", + ), + "hellaswag": HFDataset( + name="hellaswag", + url="hf://hellaswag?split=validation&trust_remote_code=True", + ), +} + +for k, v in DATASETS_REGISTRY.items(): + DatasetRegistry.register(k, v) diff --git a/llama_stack/distribution/registry/datasets/dataset.py b/llama_stack/distribution/registry/datasets/dataset.py new file mode 100644 index 000000000..1a16a5c51 --- /dev/null +++ b/llama_stack/distribution/registry/datasets/dataset.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
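Dataset registration is hard-coded at import time for now (see the TODO above); adding another source is one more register call. A sketch with a hypothetical CSV URL:

    from llama_stack.distribution.registry.datasets.dataset import CustomDataset
    from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry

    # hypothetical dataset; any .csv/.xlsx URL that CustomDataset.load() understands
    DatasetRegistry.register(
        "my-eval-set",
        CustomDataset(name="my_eval_set", url="https://example.com/my_eval_set.csv"),
    )

    ds = DatasetRegistry.get_dataset("my-eval-set")
    ds.load()
    first_row = next(iter(ds))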
+ +from abc import ABC, abstractmethod +from urllib.parse import parse_qs, urlparse + +import pandas +from datasets import Dataset, load_dataset + + +class BaseDataset(ABC): + def __init__(self, name: str): + self.dataset = None + self.dataset_id = name + self.type = self.__class__.__name__ + + def __iter__(self): + return iter(self.dataset) + + @abstractmethod + def load(self): + pass + + +class CustomDataset(BaseDataset): + def __init__(self, name, url): + super().__init__(name) + self.url = url + + def load(self): + if self.dataset: + return + # TODO: better support w/ data url + if self.url.endswith(".csv"): + df = pandas.read_csv(self.url) + elif self.url.endswith(".xlsx"): + df = pandas.read_excel(self.url) + + self.dataset = Dataset.from_pandas(df) + + +class HFDataset(BaseDataset): + def __init__(self, name, url): + super().__init__(name) + self.url = url + + def load(self): + if self.dataset: + return + + parsed = urlparse(self.url) + + if parsed.scheme != "hf": + raise ValueError(f"Unknown HF dataset: {self.url}") + + query = parse_qs(parsed.query) + query = {k: v[0] for k, v in query.items()} + path = parsed.netloc + self.dataset = load_dataset(path, **query) diff --git a/llama_stack/distribution/registry/datasets/dataset_registry.py b/llama_stack/distribution/registry/datasets/dataset_registry.py new file mode 100644 index 000000000..9ddaa8bb7 --- /dev/null +++ b/llama_stack/distribution/registry/datasets/dataset_registry.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import AbstractSet, Dict + +from .dataset import BaseDataset + + +class DatasetRegistry: + _REGISTRY: Dict[str, BaseDataset] = {} + + @staticmethod + def names() -> AbstractSet[str]: + return DatasetRegistry._REGISTRY.keys() + + @staticmethod + def register(name: str, task: BaseDataset) -> None: + if name in DatasetRegistry._REGISTRY: + raise ValueError(f"Dataset {name} already exists.") + DatasetRegistry._REGISTRY[name] = task + + @staticmethod + def get_dataset(name: str) -> BaseDataset: + if name not in DatasetRegistry._REGISTRY: + raise ValueError(f"Dataset {name} not found.") + return DatasetRegistry._REGISTRY[name] + + @staticmethod + def reset() -> None: + DatasetRegistry._REGISTRY = {} diff --git a/llama_stack/distribution/registry/tasks/__init__.py b/llama_stack/distribution/registry/tasks/__init__.py new file mode 100644 index 000000000..01ccb18ae --- /dev/null +++ b/llama_stack/distribution/registry/tasks/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +# TODO: make these import config based +from llama_stack.providers.impls.meta_reference.evals.tasks.mmlu_task import MMLUTask +from .task_registry import TaskRegistry + +TaskRegistry.register( + "mmlu", + MMLUTask, +) diff --git a/llama_stack/distribution/registry/tasks/task.py b/llama_stack/distribution/registry/tasks/task.py new file mode 100644 index 000000000..a92e6241b --- /dev/null +++ b/llama_stack/distribution/registry/tasks/task.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
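The hf:// convention used by HFDataset maps directly onto datasets.load_dataset: the netloc becomes the dataset path and each query parameter is passed through as a keyword argument (as a string, since parse_qs yields strings). The hellaswag entry registered earlier therefore resolves to roughly:

    from datasets import load_dataset

    # "hf://hellaswag?split=validation&trust_remote_code=True" parses into:
    dataset = load_dataset("hellaswag", split="validation", trust_remote_code="True")

Note that flag-like query values arrive as the string "True", which happens to behave as a truthy value here.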
+from abc import ABC, abstractmethod + + +class BaseTask(ABC): + """ + A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods. + Base class for all evaluation tasks. Each task needs to implement the following methods: + - F1: preprocess_sample(self) + - F2: postprocess_sample(self) + - F3: score_sample(self) + """ + + def __init__(self, dataset, *args, **kwargs): + super().__init__(*args, **kwargs) + self._name = self.__class__.__name__ + self.dataset = dataset + + @abstractmethod + def preprocess_sample(self, sample): + raise NotImplementedError() + + @abstractmethod + def postprocess_sample(self, sample): + raise NotImplementedError() + + @abstractmethod + def score_sample(self, sample, ground_truth): + raise NotImplementedError() + + @abstractmethod + def aggregate_results(self, eval_results): + raise NotImplementedError() + + def preprocess(self): + return [self.preprocess_sample(sample) for sample in self.dataset] + + def postprocess(self, generation): + return [self.postprocess_sample(sample) for sample in generation] + + def score(self, postprocessed): + return [ + self.score_sample(sample, ground_truth) + for sample, ground_truth in zip(postprocessed, self.dataset) + ] diff --git a/llama_stack/distribution/registry/tasks/task_registry.py b/llama_stack/distribution/registry/tasks/task_registry.py new file mode 100644 index 000000000..063894e48 --- /dev/null +++ b/llama_stack/distribution/registry/tasks/task_registry.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import AbstractSet, Dict + +from .task import BaseTask + + +class TaskRegistry: + _REGISTRY: Dict[str, BaseTask] = {} + + @staticmethod + def names() -> AbstractSet[str]: + return TaskRegistry._REGISTRY.keys() + + @staticmethod + def register(name: str, task: BaseTask) -> None: + if name in TaskRegistry._REGISTRY: + raise ValueError(f"Task {name} already exists.") + TaskRegistry._REGISTRY[name] = task + + @staticmethod + def get_task(name: str) -> BaseTask: + if name not in TaskRegistry._REGISTRY: + raise ValueError(f"Task {name} not found.") + return TaskRegistry._REGISTRY[name] + + @staticmethod + def reset() -> None: + TaskRegistry._REGISTRY = {} diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index a05e08cd7..672a4ea60 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -12,6 +12,7 @@ from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.apis.agents import Agents +from llama_stack.apis.evals import Evals from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.memory import Memory @@ -38,6 +39,7 @@ def api_protocol_map() -> Dict[Api, Any]: Api.safety: Safety, Api.shields: Shields, Api.telemetry: Telemetry, + Api.evals: Evals, } diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 777cd855b..50ab0691b 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -28,6 +28,7 @@ class Api(Enum): models = "models" shields = "shields" memory_banks = "memory_banks" + evals = "evals" # built-in API inspect = "inspect" diff --git 
a/llama_stack/providers/impls/meta_reference/evals/__init__.py b/llama_stack/providers/impls/meta_reference/evals/__init__.py new file mode 100644 index 000000000..f4dd4b79d --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/evals/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import MetaReferenceEvalsImplConfig # noqa +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.distribution.datatypes import Api, ProviderSpec + + +async def get_provider_impl( + config: MetaReferenceEvalsImplConfig, deps: Dict[Api, ProviderSpec] +): + from .evals import MetaReferenceEvalsImpl + + impl = MetaReferenceEvalsImpl(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/impls/meta_reference/evals/config.py b/llama_stack/providers/impls/meta_reference/evals/config.py new file mode 100644 index 000000000..05dee366e --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/evals/config.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class MetaReferenceEvalsImplConfig(BaseModel): ... diff --git a/llama_stack/providers/impls/meta_reference/evals/evals.py b/llama_stack/providers/impls/meta_reference/evals/evals.py new file mode 100644 index 000000000..5f475c539 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/evals/evals.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
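Wiring follows the standard inline-provider pattern: the resolver calls get_provider_impl with the provider config and the already-resolved dependencies, keyed by Api. Schematically (inside an async setup path; inference_impl is whatever Inference provider the run config resolved):

    from llama_stack.providers.datatypes import Api
    from llama_stack.providers.impls.meta_reference.evals import get_provider_impl
    from llama_stack.providers.impls.meta_reference.evals.config import MetaReferenceEvalsImplConfig

    evals_impl = await get_provider_impl(
        MetaReferenceEvalsImplConfig(),
        {Api.inference: inference_impl},
    )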
+
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.evals import *  # noqa: F403
+from termcolor import cprint
+
+from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
+
+from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
+
+from .config import MetaReferenceEvalsImplConfig
+
+
+class MetaReferenceEvalsImpl(Evals):
+    def __init__(self, config: MetaReferenceEvalsImplConfig, inference_api: Inference):
+        self.inference_api = inference_api
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def run_evals(
+        self,
+        model: str,
+        task: str,
+        dataset: Optional[str] = None,
+        eval_task_config: Optional[EvaluateTaskConfig] = None,
+    ) -> EvaluateResponse:
+        cprint(
+            f"model={model}, dataset={dataset}, task={task}, eval_task_config={eval_task_config}",
+            "red",
+        )
+        if not dataset:
+            raise ValueError("a dataset must be specified for meta-reference evals")
+
+        dataset = DatasetRegistry.get_dataset(dataset)
+        dataset.load()
+
+        task_impl = TaskRegistry.get_task(task)(dataset)
+        x1 = task_impl.preprocess()
+
+        # TODO: replace w/ batch inference & async return eval job
+        generation_outputs = []
+        if eval_task_config is None:
+            eval_task_config = EvaluateTaskConfig(n_samples=len(x1))
+        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(x1):
+            eval_task_config.n_samples = len(x1)
+
+        print(
+            f"Starting eval generation on {eval_task_config.n_samples} samples"
+        )
+
+        for msg in x1[: eval_task_config.n_samples]:
+            print("generation for msg: ", msg)
+            response = await self.inference_api.chat_completion(
+                model=model,
+                messages=[msg],
+                stream=False,
+            )
+            generation_outputs.append(response.completion_message.content)
+
+        x2 = task_impl.postprocess(generation_outputs)
+        eval_results = task_impl.score(x2)
+        eval_response = task_impl.aggregate_results(eval_results)
+        return eval_response
diff --git a/llama_stack/providers/impls/meta_reference/evals/tasks/__init__.py b/llama_stack/providers/impls/meta_reference/evals/tasks/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/evals/tasks/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py b/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py
new file mode 100644
index 000000000..673a95379
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py
@@ -0,0 +1,150 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import re
+
+from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.distribution.registry.tasks.task import BaseTask
+
+QUERY_TEMPLATE_MULTICHOICE = """
+Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
+ +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + +MULTILINGUAL_ANSWER_REGEXES = [ + r"Answer\s*:", + r"Answer\s*:​​​​​​", # Korean invisible character + r"উত্তর\s*:", + r"उत्तर\s*:", + r"উত্তরঃ", + r"উত্তর\s*:", + r"Antwort\s*:", + r"답변\s*:", + r"정답\s*:", + r"답\s*:", + r"答案\s*:", + r"答案\s*:", + r"答\s*:", + r"答\s*:", + r"答复\s*:", + r"答曰\s*:", + r"الإجابة:", + r"الجواب:", + r"إجابة:", + r"الإجابة النهائية:", + r"الإجابة الصحيحة:", + r"الإجابة الصحيحة هي:", + r"الإجابة هي:", + r"Respuesta\s*:", + r"Risposta\s*:", + r"答え\s*:", + r"答え\s*:", + r"回答\s*:", + r"回答\s*:", + r"解答\s*:", + r"Jawaban\s*:", + r"Réponse\s*:", + r"Resposta\s*:", + r"Jibu\s*:", + r"Idahun\s*:", + r"Ìdáhùn\s*:", + r"Idáhùn\s*:", + r"Àmọ̀nà\s*:", + r"Àdáhùn\s*:", + r"Ànúgọ\s*:", + r"Àṣàyàn\s*:", +] + +MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = ( + r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])" +) + + +def normalize_response(response: str) -> str: + """ + Normalize the response by removing markdown and LaTeX formatting that may prevent a match. + """ + + return ( + response.replace("**", "") + .replace("$\\boxed{", "") + .replace("}$", "") + .replace("\\$", "") + .replace("$\\text{", "") + .replace("$", "") + .replace("\\mathrm{", "") + .replace("\\{", "") + .replace("\\text", "") + .replace("\\(", "") + .replace("\\mathbf{", "") + .replace("{", "") + .replace("\\boxed", "") + ) + + +def normalize_extracted_answer(extracted_answer: str) -> str: + return ( + # In arabic these are the letters used for A-D in multiple choice questions + extracted_answer.replace("أ", " A") + .replace("ب", " B") + .replace("ج", " C") + .replace("د", " D") + # In Bengali these are the letters used for A-D in multiple choice questions + .replace("অ", " A") + .replace("ব", " B") + .replace("ড", " C") + .replace("ঢ", " D") + # In Japanese these are the letters sometimes used for A-D in multiple choice questions + .replace("A", " A") + .replace("B", " B") + .replace("C", " C") + .replace("D", " D") + .strip() + ) + + +class MMLUTask(BaseTask): + """ + MMLU Task. + """ + + def __init__(self, dataset, *args, **kwargs): + super().__init__(dataset, *args, **kwargs) + + def preprocess_sample(self, sample): + content = QUERY_TEMPLATE_MULTICHOICE.format(**sample) + return { + "role": "user", + "content": content, + } + + def postprocess_sample(self, sample): + normalized = normalize_response(sample) + return normalized + + def score_sample(self, sample, expected): + extracted_answer = None + for answer_regex in MULTILINGUAL_ANSWER_REGEXES: + regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex) + match = re.search(regex, sample) + if match: + extracted_answer = normalize_extracted_answer(match.group(1)) + break + score = ( + 1.0 if extracted_answer and extracted_answer == expected["Answer"] else 0.0 + ) + # TODO: generalize this into SingleEvalResult + return score + + def aggregate_results(self, eval_results): + return EvaluateResponse( + metrics={"score": str(sum(eval_results) / len(eval_results))} + ) diff --git a/llama_stack/providers/impls/third_party/evals/__init__.py b/llama_stack/providers/impls/third_party/evals/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/impls/third_party/evals/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
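Putting the pieces together, a single row flows through MMLUTask as sketched below (a made-up row; the column names are the ones the prompt template and scorer expect):

    row = {
        "Question": "What is 2 + 2?",
        "A": "3", "B": "4", "C": "5", "D": "22",
        "Answer": "B",
    }
    task = MMLUTask(dataset=[row])

    msg = task.preprocess_sample(row)        # {"role": "user", "content": "Answer the following ..."}
    raw = "Simple arithmetic.\nAnswer: B"    # pretend model output
    cleaned = task.postprocess_sample(raw)   # markdown/LaTeX markers stripped
    score = task.score_sample(cleaned, row)  # 1.0 -- regex extracts "B", which matches row["Answer"]
    print(task.aggregate_results([score]))   # EvaluateResponse(metrics={"score": "1.0"})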
diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/__init__.py b/llama_stack/providers/impls/third_party/evals/eleuther/__init__.py new file mode 100644 index 000000000..9886ed6d6 --- /dev/null +++ b/llama_stack/providers/impls/third_party/evals/eleuther/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import EleutherEvalsImplConfig # noqa +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.distribution.datatypes import Api, ProviderSpec + + +async def get_provider_impl( + config: EleutherEvalsImplConfig, deps: Dict[Api, ProviderSpec] +): + from .eleuther import EleutherEvalsAdapter + + impl = EleutherEvalsAdapter(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/config.py b/llama_stack/providers/impls/third_party/evals/eleuther/config.py new file mode 100644 index 000000000..a9ab297b4 --- /dev/null +++ b/llama_stack/providers/impls/third_party/evals/eleuther/config.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class EleutherEvalsImplConfig(BaseModel): ... diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py b/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py new file mode 100644 index 000000000..b9f9505e9 --- /dev/null +++ b/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
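The adapter below has to call the stack's async Inference API from lm-eval's synchronous generate_until hook. It does this by parking a dedicated event loop on a daemon thread and submitting coroutines to it. The pattern in isolation, as a self-contained sketch:

    import asyncio
    import threading

    _loop = asyncio.new_event_loop()
    _thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)

    async def double(x: int) -> int:
        await asyncio.sleep(0)
        return x * 2

    def call_from_sync(x: int) -> int:
        if not _thr.is_alive():
            _thr.start()
        # block this synchronous caller until the coroutine finishes on the loop thread
        future = asyncio.run_coroutine_threadsafe(double(x), _loop)
        return future.result()

    print(call_from_sync(21))  # 42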
+
+import asyncio
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.evals import *  # noqa: F403
+import os
+import random
+import threading
+from pathlib import Path
+
+import lm_eval
+import tqdm
+from lm_eval.api.model import LM
+from lm_eval.evaluator import evaluate, get_task_list
+from lm_eval.tasks import get_task_dict, TaskManager
+from termcolor import cprint
+
+from .config import EleutherEvalsImplConfig
+
+
+# https://stackoverflow.com/questions/74703727/how-to-call-async-function-from-sync-funcion-and-get-result-while-a-loop-is-alr
+# We will use another thread with its own event loop to run the async api within a sync function
+_loop = asyncio.new_event_loop()
+_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)
+
+
+class EleutherEvalsWrapper(LM):
+    def __init__(
+        self,
+        inference_api: Inference,
+        model: str,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.inference_api = inference_api
+        self.model = model
+        self.tokenizer = None
+        self.tokenized_requests = False
+        self.kwargs = kwargs
+
+    @property
+    def eot_token_id(self):
+        raise NotImplementedError("Not implemented")
+
+    @property
+    def max_length(self) -> int:
+        raise NotImplementedError("Not implemented")
+
+    @property
+    def max_gen_toks(self) -> int:
+        raise NotImplementedError("Not implemented")
+
+    @property
+    def batch_size(self):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError("No support for logits.")
+
+    @property
+    def device(self):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError("No support for logits.")
+
+    @property
+    def world_size(self):
+        return 1
+
+    def tok_encode(self, string: str) -> List[int]:
+        raise NotImplementedError("Not implemented")
+
+    def tok_decode(self, tokens: List[int]) -> str:
+        raise NotImplementedError("Not implemented")
+
+    def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
+        raise NotImplementedError("No support for logits.")
+
+    def _model_call(self, inps):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError()
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        # Isn't used because we override generate_until
+        raise NotImplementedError()
+
+    def loglikelihood(self, requests, disable_tqdm: bool = False):
+        # TODO: implement inference completion with loglikelihood
+        res = []
+        for req in requests:
+            res.append((-random.random(), False))
+
+        return res
+
+    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
+        raise NotImplementedError("No support for logits.")
+
+    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
+        res = []
+        if not _thr.is_alive():
+            _thr.start()
+        for req in tqdm.tqdm(requests):
+            chat_completion_coro_fn = self.inference_api.chat_completion(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": req.args[0],
+                    }
+                ],
+                stream=False,
+            )
+            future = asyncio.run_coroutine_threadsafe(chat_completion_coro_fn, _loop)
+            response = future.result()
+            res.append(response.completion_message.content)
+
+        return res
+
+
+class EleutherEvalsAdapter(Evals):
+    def __init__(self, config: EleutherEvalsImplConfig, inference_api: Inference):
+        self.inference_api = inference_api
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def run_evals(
+        self,
+        model: str,
+        task: str,
+        dataset: Optional[str] = None,
+        eval_task_config: Optional[EvaluateTaskConfig] = None,
+    ) -> 
EvaluateResponse: + cprint(f"Eleuther Evals: {model} {dataset} {task}", "red") + + eluther_wrapper = EleutherEvalsWrapper(self.inference_api, model) + current_dir = Path(os.path.dirname(os.path.abspath(__file__))) + + # custom registry of harness tasks + task_manager = TaskManager( + include_path=str(current_dir / "tasks"), + ) + + task_dict = get_task_dict(task, task_manager) + cprint(task_dict, "blue") + + task_types = set([t.task.OUTPUT_TYPE for t in get_task_list(task_dict)]) + cprint(task_types, "cyan") + + output = evaluate( + eluther_wrapper, + task_dict, + limit=eval_task_config.n_samples, + ) + + formatted_output = lm_eval.utils.make_table(output) + + cprint(formatted_output, "green") + + return EvaluateResponse( + metrics={ + "metrics_table": formatted_output, + }, + ) diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_ifeval/ifeval.yaml b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_ifeval/ifeval.yaml new file mode 100644 index 000000000..e10277a31 --- /dev/null +++ b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_ifeval/ifeval.yaml @@ -0,0 +1,32 @@ +task: meta_ifeval +dataset_path: meta-llama/Llama-3.1-8B-Instruct-evals +dataset_name: Llama-3.1-8B-Instruct-evals__ifeval__strict__details +output_type: generate_until +test_split: latest +process_docs: !function utils.process_docs +num_fewshot: 0 +doc_to_text: prompt +doc_to_target: 0 +generation_kwargs: + until: [] + do_sample: false + temperature: 0.0 + max_gen_toks: 1280 +process_results: !function utils.process_results +metric_list: + - metric: prompt_level_strict_acc + aggregation: mean + higher_is_better: true + - metric: inst_level_strict_acc + aggregation: !function utils.agg_inst_level_acc + higher_is_better: true + - metric: prompt_level_loose_acc + aggregation: mean + higher_is_better: true + - metric: inst_level_loose_acc + aggregation: !function utils.agg_inst_level_acc + higher_is_better: true +metadata: + version: 2.0 +fewshot_config: + sampler: first_n diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_ifeval/utils.py b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_ifeval/utils.py new file mode 100644 index 000000000..aa171343f --- /dev/null +++ b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_ifeval/utils.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
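Note that this path reads eval_task_config.n_samples unconditionally when setting the harness limit, so callers are expected to pass a config (n_samples=None means "run the full task"). An invocation sketch against a resolved Eleuther provider (evals_impl here is hypothetical):

    # inside an async caller, with evals_impl already resolved by the stack
    response = await evals_impl.run_evals(
        model="Llama3.1-8B-Instruct",
        task="meta_mmlu_pro_instruct",  # or "meta_ifeval"
        eval_task_config=EvaluateTaskConfig(n_samples=2),
    )
    print(response.metrics["metrics_table"])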
+ +import dataclasses +from typing import Dict, Optional, Union + +import datasets + +from lm_eval.tasks.ifeval import instructions_registry + + +@dataclasses.dataclass +class InputExample: + key: int + instruction_id_list: list[str] + prompt: str + kwargs: list[Dict[str, Optional[Union[str, int]]]] + + +@dataclasses.dataclass +class OutputExample: + instruction_id_list: list[str] + prompt: str + response: str + follow_all_instructions: bool + follow_instruction_list: list[bool] + + +def test_instruction_following_strict( + inp, + response, +): + """Tests response to see if instructions are followed.""" + instruction_list = inp.instruction_id_list + is_following_list = [] + + for index, instruction_id in enumerate(instruction_list): + instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id] + instruction = instruction_cls(instruction_id) + + # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method. + kwargs = {k: v for k, v in inp.kwargs[index].items() if v} + instruction.build_description(**kwargs) + args = instruction.get_instruction_args() + if args and "prompt" in args: + instruction.build_description(prompt=inp.prompt) + + if response.strip() and instruction.check_following(response): + is_following_list.append(True) + else: + is_following_list.append(False) + + return OutputExample( + instruction_id_list=inp.instruction_id_list, + prompt=inp.prompt, + response=response, + follow_all_instructions=all(is_following_list), + follow_instruction_list=is_following_list, + ) + + +def test_instruction_following_loose( + inp, + response, +): + """Tests response for an upper bound for following instructions.""" + r = response.split("\n") + response_remove_first = "\n".join(r[1:]).strip() + response_remove_last = "\n".join(r[:-1]).strip() + response_remove_both = "\n".join(r[1:-1]).strip() + revised_response = response.replace("*", "") + revised_response_remove_first = response_remove_first.replace("*", "") + revised_response_remove_last = response_remove_last.replace("*", "") + revised_response_remove_both = response_remove_both.replace("*", "") + all_responses = [ + response, + revised_response, + response_remove_first, + response_remove_last, + response_remove_both, + revised_response_remove_first, + revised_response_remove_last, + revised_response_remove_both, + ] + instruction_list = inp.instruction_id_list + is_following_list = [] + + for index, instruction_id in enumerate(instruction_list): + instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id] + instruction = instruction_cls(instruction_id) + + # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method. 
+ kwargs = {k: v for k, v in inp.kwargs[index].items() if v} + instruction.build_description(**kwargs) + args = instruction.get_instruction_args() + if args and "prompt" in args: + instruction.build_description(prompt=inp.prompt) + + is_following = False + for r in all_responses: + if r.strip() and instruction.check_following(r): + is_following = True + break + + is_following_list.append(is_following) + + return OutputExample( + instruction_id_list=inp.instruction_id_list, + prompt=inp.prompt, + response=response, + follow_all_instructions=all(is_following_list), + follow_instruction_list=is_following_list, + ) + + +def process_results(doc, results): + new_kwargs = [] + for item in doc["kwargs"]: + if item["nth_paragraph"]: + item["nth_paragraph"] = int(item["nth_paragraph"]) + new_kwargs.append(item) + inp = InputExample( + key=doc["key"], + instruction_id_list=doc["instruction_id_list"], + prompt=doc["prompt"], + kwargs=new_kwargs, + ) + response = results[0] + + out_strict = test_instruction_following_strict(inp, response) + out_loose = test_instruction_following_loose(inp, response) + + return { + "prompt_level_strict_acc": out_strict.follow_all_instructions, + "inst_level_strict_acc": out_strict.follow_instruction_list, + "prompt_level_loose_acc": out_loose.follow_all_instructions, + "inst_level_loose_acc": out_loose.follow_instruction_list, + } + + +def agg_inst_level_acc(items): + flat_items = [item for sublist in items for item in sublist] + inst_level_acc = sum(flat_items) / len(flat_items) + return inst_level_acc + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _get_question(example: dict) -> dict: + # get the question from the ifeval dataset + example["input_question"] = ( + eval( + example["input_question"] + .replace("null", "None") + .replace("true", "True") + .replace("false", "False") + )["dialog"][0]["body"] + .replace("Is it True that the first song", "Is it true that the first song") + .replace("Is the following True", "Is the following true") + ) + example["input_final_prompts"] = example["input_final_prompts"][0] + return example + + original_dataset_name = "wis-k/instruction-following-eval" + ifeval_data = datasets.load_dataset(original_dataset_name, split="train") + ifeval_df = ifeval_data.to_pandas() + ifeval_df = ifeval_df.rename(columns={"prompt": "input_question"}) + + meta_dataset = dataset.map(_get_question) + meta_df = meta_dataset.to_pandas() + + # join the two datasets on the input_question column + joined = meta_df.join(ifeval_df.set_index("input_question"), on="input_question") + joined = joined.rename(columns={"input_final_prompts": "prompt"}) + joined = joined.rename(columns={"is_correct": "previous_is_correct"}) + joined = datasets.Dataset.from_pandas(joined) + joined = joined.select_columns( + [ + "input_question", + "prompt", + "previous_is_correct", + "instruction_id_list", + "kwargs", + "output_prediction_text", + "key", + ] + ) + joined.rename_column("output_prediction_text", "previous_output_prediction_text") + return joined diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml new file mode 100644 index 000000000..1ec3c107d --- /dev/null +++ b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml @@ -0,0 +1,29 @@ +task: meta_mmlu_pro_instruct +dataset_path: 
meta-llama/Llama-3.1-8B-Instruct-evals +dataset_name: Llama-3.1-8B-Instruct-evals__mmlu_pro__details +test_split: latest +output_type: generate_until +process_docs: !function utils.process_docs +doc_to_text: !function utils.doc_to_text +doc_to_target: gold +filter_list: + - name: "strict-match" + filter: + - function: "regex" + group_select: -1 + regex_pattern: 'best answer is ([A-Z])' + - function: "take_first" +generation_kwargs: + until: [] + do_sample: false + temperature: 0 + max_gen_toks: 1024 +num_fewshot: 0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 1.0 diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_mmlu_pro/utils.py b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_mmlu_pro/utils.py new file mode 100644 index 000000000..6b8bc3e5b --- /dev/null +++ b/llama_stack/providers/impls/third_party/evals/eleuther/tasks/meta_mmlu_pro/utils.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import datasets + + +def doc_to_text(doc: dict) -> str: + return doc["input_final_prompts"][0] + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc: dict) -> dict: + out_doc = { + "problem": doc["input_question"], + "gold": doc["input_correct_responses"][0], + } + return out_doc + + dataset = dataset.select_columns( + [ + "input_question", + "input_correct_responses", + "input_final_prompts", + "is_correct", + "input_question_hash", + "input_choice_list", + "output_prediction_text", + ], + ) + dataset = dataset.rename_column("is_correct", "previously_is_correct") + dataset = dataset.map(_process_doc) + return dataset diff --git a/llama_stack/providers/registry/evals.py b/llama_stack/providers/registry/evals.py new file mode 100644 index 000000000..8693ec603 --- /dev/null +++ b/llama_stack/providers/registry/evals.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
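The strict-match filter in the MMLU-Pro config keeps only the final "best answer is X" occurrence of each generation (group_select: -1) before exact-match scoring; it is roughly equivalent to:

    import re

    generation = "... the best answer is B. On reflection, the best answer is C."
    matches = re.findall(r"best answer is ([A-Z])", generation)
    extracted = matches[-1] if matches else ""  # "C"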
+ +from typing import List + +from llama_stack.distribution.datatypes import * # noqa: F403 + + +def available_providers() -> List[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.evals, + provider_type="meta-reference", + pip_packages=[ + "matplotlib", + "pillow", + "pandas", + "scikit-learn", + "datasets", + ], + module="llama_stack.providers.impls.meta_reference.evals", + config_class="llama_stack.providers.impls.meta_reference.evals.MetaReferenceEvalsImplConfig", + api_dependencies=[ + Api.inference, + ], + ), + InlineProviderSpec( + api=Api.evals, + provider_type="eleuther", + pip_packages=[ + "lm-eval", + ], + module="llama_stack.providers.impls.third_party.evals.eleuther", + config_class="llama_stack.providers.impls.third_party.evals.eleuther.EleutherEvalsImplConfig", + api_dependencies=[ + Api.inference, + ], + ), + ] diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 9fffc0f99..207064904 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -152,7 +152,7 @@ def severity(levelname: str) -> LogSeverity: elif levelname == "INFO": return LogSeverity.INFO elif levelname == "WARNING": - return LogSeverity.WARNING + return LogSeverity.WARN elif levelname == "ERROR": return LogSeverity.ERROR elif levelname == "CRITICAL": diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml index e12f6e852..1422d6ee2 100644 --- a/tests/examples/local-run.yaml +++ b/tests/examples/local-run.yaml @@ -11,7 +11,12 @@ apis: - memory_banks - inference - safety +- evals providers: + evals: + - provider_id: meta-reference + provider_type: meta-reference + config: {} inference: - provider_id: meta-reference provider_type: meta-reference
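With the API, providers, and run config in place, a quick wiring check from a Python shell is shown below; importing the registry packages triggers the hard-coded registrations added above.

    from llama_stack.distribution.registry.datasets import DatasetRegistry
    from llama_stack.distribution.registry.tasks import TaskRegistry

    print(list(DatasetRegistry.names()))  # ['mmlu-simple-eval-en', 'hellaswag']
    print(list(TaskRegistry.names()))     # ['mmlu']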