tasks registry

This commit is contained in:
Xi Yan 2024-10-07 15:57:39 -07:00
parent 041634192a
commit 4764762dd4
9 changed files with 74 additions and 35 deletions

View file

@ -43,20 +43,20 @@ async def run_main(host: str, port: int):
client = EvaluationClient(f"http://{host}:{port}") client = EvaluationClient(f"http://{host}:{port}")
# CustomDataset # CustomDataset
# response = await client.run_evals(
# "Llama3.1-8B-Instruct",
# "mmlu-simple-eval-en",
# "mmlu",
# )
# cprint(f"evaluate response={response}", "green")
# Eleuther Eval
response = await client.run_evals( response = await client.run_evals(
"Llama3.1-8B-Instruct", "Llama3.1-8B-Instruct",
"PLACEHOLDER_DATASET_NAME", "mmlu-simple-eval-en",
"mmlu", "mmlu",
) )
cprint(response.metrics["metrics_table"], "red") cprint(f"evaluate response={response}", "green")
# Eleuther Eval
# response = await client.run_evals(
# "Llama3.1-8B-Instruct",
# "PLACEHOLDER_DATASET_NAME",
# "mmlu",
# )
# cprint(response.metrics["metrics_table"], "red")
def main(host: str, port: int): def main(host: str, port: int):

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# TODO: make these import config based
from llama_stack.providers.impls.meta_reference.evals.tasks.mmlu_task import MMLUTask
from .task_registry import TaskRegistry
TaskRegistry.register(
"mmlu",
MMLUTask,
)

View file

@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
class BaseTask(ABC): class BaseTask(ABC):
""" """
A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods.
Base class for all evaluation tasks. Each task needs to implement the following methods: Base class for all evaluation tasks. Each task needs to implement the following methods:
- F1: preprocess_sample(self) - F1: preprocess_sample(self)
- F2: postprocess_sample(self) - F2: postprocess_sample(self)

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AbstractSet, Dict
from .task import BaseTask
class TaskRegistry:
_REGISTRY: Dict[str, BaseTask] = {}
@staticmethod
def names() -> AbstractSet[str]:
return TaskRegistry._REGISTRY.keys()
@staticmethod
def register(name: str, task: BaseTask) -> None:
if name in TaskRegistry._REGISTRY:
raise ValueError(f"Task {name} already exists.")
TaskRegistry._REGISTRY[name] = task
@staticmethod
def get_task(name: str) -> BaseTask:
if name not in TaskRegistry._REGISTRY:
raise ValueError(f"Task {name} not found.")
return TaskRegistry._REGISTRY[name]
@staticmethod
def reset() -> None:
TaskRegistry._REGISTRY = {}

View file

@ -8,12 +8,15 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.evals import * # noqa: F403 from llama_stack.apis.evals import * # noqa: F403
from termcolor import cprint from termcolor import cprint
from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import ( from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import (
get_dataset, get_dataset,
) )
from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import (
get_task, # from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import (
) # get_task,
# )
from .config import MetaReferenceEvalsImplConfig from .config import MetaReferenceEvalsImplConfig
@ -36,7 +39,8 @@ class MetaReferenceEvalsImpl(Evals):
) -> EvaluateResponse: ) -> EvaluateResponse:
cprint(f"model={model}, dataset={dataset}, task={task}", "red") cprint(f"model={model}, dataset={dataset}, task={task}", "red")
dataset = get_dataset(dataset) dataset = get_dataset(dataset)
task_impl = get_task(task, dataset) task_impl = TaskRegistry.get_task(task)(dataset)
x1 = task_impl.preprocess() x1 = task_impl.preprocess()
# TODO: replace w/ batch inference & async return eval job # TODO: replace w/ batch inference & async return eval job

View file

@ -5,8 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import re import re
from .task import BaseTask
from llama_stack.apis.evals import * # noqa: F403 from llama_stack.apis.evals import * # noqa: F403
from llama_stack.distribution.registry.tasks.task import BaseTask
QUERY_TEMPLATE_MULTICHOICE = """ QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.

View file

@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .mmlu_task import MMLUTask
# TODO: make this into a config based registry
TASKS_REGISTRY = {
"mmlu": MMLUTask,
}
def get_task(task_id: str, dataset):
task_impl = TASKS_REGISTRY[task_id]
return task_impl(dataset)

View file

@ -12,12 +12,12 @@ apis_to_serve:
- safety - safety
- evals - evals
api_providers: api_providers:
evals:
provider_type: eleuther
config: {}
# evals: # evals:
# provider_type: meta-reference # provider_type: eleuther
# config: {} # config: {}
evals:
provider_type: meta-reference
config: {}
inference: inference:
providers: providers:
- meta-reference - meta-reference