From 4764762dd4dbdd20141d176f1af1523c5f43b66e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 7 Oct 2024 15:57:39 -0700 Subject: [PATCH] tasks registry --- llama_stack/apis/evals/client.py | 20 ++++++------ llama_stack/distribution/registry/__init__.py | 5 +++ .../distribution/registry/tasks/__init__.py | 13 ++++++++ .../registry}/tasks/task.py | 1 + .../registry/tasks/task_registry.py | 32 +++++++++++++++++++ .../impls/meta_reference/evals/evals.py | 12 ++++--- .../meta_reference/evals/tasks/mmlu_task.py | 2 +- .../evals/tasks/task_registry.py | 16 ---------- tests/examples/local-run.yaml | 8 ++--- 9 files changed, 74 insertions(+), 35 deletions(-) create mode 100644 llama_stack/distribution/registry/__init__.py create mode 100644 llama_stack/distribution/registry/tasks/__init__.py rename llama_stack/{providers/impls/meta_reference/evals => distribution/registry}/tasks/task.py (91%) create mode 100644 llama_stack/distribution/registry/tasks/task_registry.py delete mode 100644 llama_stack/providers/impls/meta_reference/evals/tasks/task_registry.py diff --git a/llama_stack/apis/evals/client.py b/llama_stack/apis/evals/client.py index 2a6947b32..4acbff5f6 100644 --- a/llama_stack/apis/evals/client.py +++ b/llama_stack/apis/evals/client.py @@ -43,20 +43,20 @@ async def run_main(host: str, port: int): client = EvaluationClient(f"http://{host}:{port}") # CustomDataset - # response = await client.run_evals( - # "Llama3.1-8B-Instruct", - # "mmlu-simple-eval-en", - # "mmlu", - # ) - # cprint(f"evaluate response={response}", "green") - - # Eleuther Eval response = await client.run_evals( "Llama3.1-8B-Instruct", - "PLACEHOLDER_DATASET_NAME", + "mmlu-simple-eval-en", "mmlu", ) - cprint(response.metrics["metrics_table"], "red") + cprint(f"evaluate response={response}", "green") + + # Eleuther Eval + # response = await client.run_evals( + # "Llama3.1-8B-Instruct", + # "PLACEHOLDER_DATASET_NAME", + # "mmlu", + # ) + # cprint(response.metrics["metrics_table"], "red") def main(host: str, port: int): diff --git a/llama_stack/distribution/registry/__init__.py b/llama_stack/distribution/registry/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/registry/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/registry/tasks/__init__.py b/llama_stack/distribution/registry/tasks/__init__.py new file mode 100644 index 000000000..01ccb18ae --- /dev/null +++ b/llama_stack/distribution/registry/tasks/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +# TODO: make these import config based +from llama_stack.providers.impls.meta_reference.evals.tasks.mmlu_task import MMLUTask +from .task_registry import TaskRegistry + +TaskRegistry.register( + "mmlu", + MMLUTask, +) diff --git a/llama_stack/providers/impls/meta_reference/evals/tasks/task.py b/llama_stack/distribution/registry/tasks/task.py similarity index 91% rename from llama_stack/providers/impls/meta_reference/evals/tasks/task.py rename to llama_stack/distribution/registry/tasks/task.py index a5461165b..a92e6241b 100644 --- a/llama_stack/providers/impls/meta_reference/evals/tasks/task.py +++ b/llama_stack/distribution/registry/tasks/task.py @@ -8,6 +8,7 @@ from abc import ABC, abstractmethod class BaseTask(ABC): """ + A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods. Base class for all evaluation tasks. Each task needs to implement the following methods: - F1: preprocess_sample(self) - F2: postprocess_sample(self) diff --git a/llama_stack/distribution/registry/tasks/task_registry.py b/llama_stack/distribution/registry/tasks/task_registry.py new file mode 100644 index 000000000..063894e48 --- /dev/null +++ b/llama_stack/distribution/registry/tasks/task_registry.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import AbstractSet, Dict + +from .task import BaseTask + + +class TaskRegistry: + _REGISTRY: Dict[str, BaseTask] = {} + + @staticmethod + def names() -> AbstractSet[str]: + return TaskRegistry._REGISTRY.keys() + + @staticmethod + def register(name: str, task: BaseTask) -> None: + if name in TaskRegistry._REGISTRY: + raise ValueError(f"Task {name} already exists.") + TaskRegistry._REGISTRY[name] = task + + @staticmethod + def get_task(name: str) -> BaseTask: + if name not in TaskRegistry._REGISTRY: + raise ValueError(f"Task {name} not found.") + return TaskRegistry._REGISTRY[name] + + @staticmethod + def reset() -> None: + TaskRegistry._REGISTRY = {} diff --git a/llama_stack/providers/impls/meta_reference/evals/evals.py b/llama_stack/providers/impls/meta_reference/evals/evals.py index a52615ed8..82c2a542b 100644 --- a/llama_stack/providers/impls/meta_reference/evals/evals.py +++ b/llama_stack/providers/impls/meta_reference/evals/evals.py @@ -8,12 +8,15 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.evals import * # noqa: F403 from termcolor import cprint +from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry + from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import ( get_dataset, ) -from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import ( - get_task, -) + +# from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import ( +# get_task, +# ) from .config import MetaReferenceEvalsImplConfig @@ -36,7 +39,8 @@ class MetaReferenceEvalsImpl(Evals): ) -> EvaluateResponse: cprint(f"model={model}, dataset={dataset}, task={task}", "red") dataset = get_dataset(dataset) - task_impl = get_task(task, dataset) + task_impl = TaskRegistry.get_task(task)(dataset) + x1 = task_impl.preprocess() # TODO: replace w/ batch inference & async return eval job diff --git a/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py b/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py index d74476628..673a95379 100644 --- a/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py +++ b/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py @@ -5,8 +5,8 @@ # the root directory of this source tree. import re -from .task import BaseTask from llama_stack.apis.evals import * # noqa: F403 +from llama_stack.distribution.registry.tasks.task import BaseTask QUERY_TEMPLATE_MULTICHOICE = """ Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. diff --git a/llama_stack/providers/impls/meta_reference/evals/tasks/task_registry.py b/llama_stack/providers/impls/meta_reference/evals/tasks/task_registry.py deleted file mode 100644 index 69720051d..000000000 --- a/llama_stack/providers/impls/meta_reference/evals/tasks/task_registry.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from .mmlu_task import MMLUTask - -# TODO: make this into a config based registry -TASKS_REGISTRY = { - "mmlu": MMLUTask, -} - - -def get_task(task_id: str, dataset): - task_impl = TASKS_REGISTRY[task_id] - return task_impl(dataset) diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml index 4a616bc88..f7e3b1655 100644 --- a/tests/examples/local-run.yaml +++ b/tests/examples/local-run.yaml @@ -12,12 +12,12 @@ apis_to_serve: - safety - evals api_providers: - evals: - provider_type: eleuther - config: {} # evals: - # provider_type: meta-reference + # provider_type: eleuther # config: {} + evals: + provider_type: meta-reference + config: {} inference: providers: - meta-reference