tasks registry

This commit is contained in:
Xi Yan 2024-10-07 15:57:39 -07:00
parent 041634192a
commit 4764762dd4
9 changed files with 74 additions and 35 deletions

View file

@ -43,20 +43,20 @@ async def run_main(host: str, port: int):
client = EvaluationClient(f"http://{host}:{port}")
# CustomDataset
# response = await client.run_evals(
# "Llama3.1-8B-Instruct",
# "mmlu-simple-eval-en",
# "mmlu",
# )
# cprint(f"evaluate response={response}", "green")
# Eleuther Eval
response = await client.run_evals(
"Llama3.1-8B-Instruct",
"PLACEHOLDER_DATASET_NAME",
"mmlu-simple-eval-en",
"mmlu",
)
cprint(response.metrics["metrics_table"], "red")
cprint(f"evaluate response={response}", "green")
# Eleuther Eval
# response = await client.run_evals(
# "Llama3.1-8B-Instruct",
# "PLACEHOLDER_DATASET_NAME",
# "mmlu",
# )
# cprint(response.metrics["metrics_table"], "red")
def main(host: str, port: int):

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# TODO: make these imports config-based
from llama_stack.providers.impls.meta_reference.evals.tasks.mmlu_task import MMLUTask
from .task_registry import TaskRegistry
# Register the MMLU task class under the name "mmlu" when this module is imported.
TaskRegistry.register("mmlu", MMLUTask)

View file

@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
class BaseTask(ABC):
"""
A task represents a single evaluation benchmark, including its dataset, preprocessing, postprocessing and scoring methods.
Base class for all evaluation tasks. Each task needs to implement the following methods:
- F1: preprocess_sample(self)
- F2: postprocess_sample(self)

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AbstractSet, Dict, Type

from .task import BaseTask
class TaskRegistry:
    """Registry mapping task names to evaluation task classes.

    The registry stores task *classes*, not task instances: a caller looks a
    class up with ``get_task`` and instantiates it itself. All state lives in
    the class-level ``_REGISTRY`` dict, and every method is a static method
    operating on that dict.
    """

    # Task name -> task class. The ``BaseTask`` reference is quoted so the
    # annotation is not evaluated when the class body runs.
    _REGISTRY: Dict[str, "Type[BaseTask]"] = {}

    @staticmethod
    def names() -> AbstractSet[str]:
        """Return the keys view of the current registry dict.

        The view tracks later ``register`` calls on the same dict. ``reset``
        rebinds ``_REGISTRY`` to a new dict, so a view obtained before a
        reset keeps referring to the old contents.
        """
        return TaskRegistry._REGISTRY.keys()

    @staticmethod
    def register(name: str, task: "Type[BaseTask]") -> None:
        """Register ``task`` (a task class) under ``name``.

        Raises:
            ValueError: if ``name`` is already registered. Existing entries
                are never overwritten.
        """
        if name in TaskRegistry._REGISTRY:
            raise ValueError(f"Task {name} already exists.")
        TaskRegistry._REGISTRY[name] = task

    @staticmethod
    def get_task(name: str) -> "Type[BaseTask]":
        """Return the task class registered under ``name``.

        Raises:
            ValueError: if no task is registered under ``name``.
        """
        if name not in TaskRegistry._REGISTRY:
            raise ValueError(f"Task {name} not found.")
        return TaskRegistry._REGISTRY[name]

    @staticmethod
    def reset() -> None:
        """Drop every registration by rebinding the registry to an empty dict."""
        TaskRegistry._REGISTRY = {}

View file

@ -8,12 +8,15 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.evals import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import (
get_dataset,
)
from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import (
get_task,
)
# from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import (
# get_task,
# )
from .config import MetaReferenceEvalsImplConfig
@ -36,7 +39,8 @@ class MetaReferenceEvalsImpl(Evals):
) -> EvaluateResponse:
cprint(f"model={model}, dataset={dataset}, task={task}", "red")
dataset = get_dataset(dataset)
task_impl = get_task(task, dataset)
task_impl = TaskRegistry.get_task(task)(dataset)
x1 = task_impl.preprocess()
# TODO: replace w/ batch inference & async return eval job

View file

@ -5,8 +5,8 @@
# the root directory of this source tree.
import re
from .task import BaseTask
from llama_stack.apis.evals import * # noqa: F403
from llama_stack.distribution.registry.tasks.task import BaseTask
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.

View file

@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .mmlu_task import MMLUTask
# TODO: make this into a config based registry
# Maps a task identifier to the task class that get_task() looks up and
# instantiates.
TASKS_REGISTRY = {
"mmlu": MMLUTask,
}
def get_task(task_id: str, dataset):
    """Build the task registered under ``task_id`` for ``dataset``.

    Looks the entry up in ``TASKS_REGISTRY`` and calls it with ``dataset``,
    returning the result. An unknown ``task_id`` propagates the ``KeyError``
    raised by the dict lookup.
    """
    return TASKS_REGISTRY[task_id](dataset)

View file

@ -12,12 +12,12 @@ apis_to_serve:
- safety
- evals
api_providers:
evals:
provider_type: eleuther
config: {}
# evals:
# provider_type: meta-reference
# provider_type: eleuther
# config: {}
evals:
provider_type: meta-reference
config: {}
inference:
providers:
- meta-reference