Mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
tasks registry
This commit is contained in:
parent
041634192a
commit
4764762dd4
9 changed files with 74 additions and 35 deletions
|
@ -43,20 +43,20 @@ async def run_main(host: str, port: int):
|
|||
client = EvaluationClient(f"http://{host}:{port}")
|
||||
|
||||
# CustomDataset
|
||||
# response = await client.run_evals(
|
||||
# "Llama3.1-8B-Instruct",
|
||||
# "mmlu-simple-eval-en",
|
||||
# "mmlu",
|
||||
# )
|
||||
# cprint(f"evaluate response={response}", "green")
|
||||
|
||||
# Eleuther Eval
|
||||
response = await client.run_evals(
|
||||
"Llama3.1-8B-Instruct",
|
||||
"PLACEHOLDER_DATASET_NAME",
|
||||
"mmlu-simple-eval-en",
|
||||
"mmlu",
|
||||
)
|
||||
cprint(response.metrics["metrics_table"], "red")
|
||||
cprint(f"evaluate response={response}", "green")
|
||||
|
||||
# Eleuther Eval
|
||||
# response = await client.run_evals(
|
||||
# "Llama3.1-8B-Instruct",
|
||||
# "PLACEHOLDER_DATASET_NAME",
|
||||
# "mmlu",
|
||||
# )
|
||||
# cprint(response.metrics["metrics_table"], "red")
|
||||
|
||||
|
||||
def main(host: str, port: int):
|
||||
|
|
5
llama_stack/distribution/registry/__init__.py
Normal file
5
llama_stack/distribution/registry/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
13
llama_stack/distribution/registry/tasks/__init__.py
Normal file
13
llama_stack/distribution/registry/tasks/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
# TODO: make these import config based
|
||||
from llama_stack.providers.impls.meta_reference.evals.tasks.mmlu_task import MMLUTask
|
||||
from .task_registry import TaskRegistry
|
||||
|
||||
# Register built-in tasks at import time; "mmlu" maps to the MMLUTask class.
TaskRegistry.register("mmlu", MMLUTask)
|
|
@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
|
|||
|
||||
class BaseTask(ABC):
|
||||
"""
|
||||
A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods.
|
||||
Base class for all evaluation tasks. Each task needs to implement the following methods:
|
||||
- F1: preprocess_sample(self)
|
||||
- F2: postprocess_sample(self)
|
32
llama_stack/distribution/registry/tasks/task_registry.py
Normal file
32
llama_stack/distribution/registry/tasks/task_registry.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import AbstractSet, Dict, Type

from .task import BaseTask
|
||||
|
||||
|
||||
class TaskRegistry:
    """Global registry mapping task names to evaluation task *classes*.

    Values are classes, not instances: callers register a class
    (e.g. ``TaskRegistry.register("mmlu", MMLUTask)``) and instantiate the
    class returned by :meth:`get_task` themselves
    (e.g. ``TaskRegistry.get_task("mmlu")(dataset)``).
    """

    # name -> task class. Annotated Type[BaseTask] (the original said
    # BaseTask, but registrations pass the class itself, not an instance).
    _REGISTRY: "Dict[str, Type[BaseTask]]" = {}

    @staticmethod
    def names() -> AbstractSet[str]:
        """Return a live view of all registered task names."""
        return TaskRegistry._REGISTRY.keys()

    @staticmethod
    def register(name: str, task: "Type[BaseTask]") -> None:
        """Register *task* (a BaseTask subclass) under *name*.

        Raises:
            ValueError: if *name* is already registered.
        """
        if name in TaskRegistry._REGISTRY:
            raise ValueError(f"Task {name} already exists.")
        TaskRegistry._REGISTRY[name] = task

    @staticmethod
    def get_task(name: str) -> "Type[BaseTask]":
        """Return the task class registered under *name*.

        Raises:
            ValueError: if *name* is not registered.
        """
        if name not in TaskRegistry._REGISTRY:
            raise ValueError(f"Task {name} not found.")
        return TaskRegistry._REGISTRY[name]

    @staticmethod
    def reset() -> None:
        """Drop every registration (intended for tests)."""
        TaskRegistry._REGISTRY = {}
|
|
@ -8,12 +8,15 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
from llama_stack.apis.evals import * # noqa: F403
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import (
|
||||
get_dataset,
|
||||
)
|
||||
from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import (
|
||||
get_task,
|
||||
)
|
||||
|
||||
# from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import (
|
||||
# get_task,
|
||||
# )
|
||||
|
||||
from .config import MetaReferenceEvalsImplConfig
|
||||
|
||||
|
@ -36,7 +39,8 @@ class MetaReferenceEvalsImpl(Evals):
|
|||
) -> EvaluateResponse:
|
||||
cprint(f"model={model}, dataset={dataset}, task={task}", "red")
|
||||
dataset = get_dataset(dataset)
|
||||
task_impl = get_task(task, dataset)
|
||||
task_impl = TaskRegistry.get_task(task)(dataset)
|
||||
|
||||
x1 = task_impl.preprocess()
|
||||
|
||||
# TODO: replace w/ batch inference & async return eval job
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
import re
|
||||
|
||||
from .task import BaseTask
|
||||
from llama_stack.apis.evals import * # noqa: F403
|
||||
from llama_stack.distribution.registry.tasks.task import BaseTask
|
||||
|
||||
QUERY_TEMPLATE_MULTICHOICE = """
|
||||
Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from .mmlu_task import MMLUTask
|
||||
|
||||
# TODO: make this into a config based registry
|
||||
# Task id -> task class; looked up by get_task below.
TASKS_REGISTRY = {"mmlu": MMLUTask}
|
||||
|
||||
|
||||
def get_task(task_id: str, dataset):
    """Return an instance of the task registered under *task_id*, constructed with *dataset*."""
    return TASKS_REGISTRY[task_id](dataset)
|
|
@ -12,12 +12,12 @@ apis_to_serve:
|
|||
- safety
|
||||
- evals
|
||||
api_providers:
|
||||
evals:
|
||||
provider_type: eleuther
|
||||
config: {}
|
||||
# evals:
|
||||
# provider_type: meta-reference
|
||||
# provider_type: eleuther
|
||||
# config: {}
|
||||
evals:
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
inference:
|
||||
providers:
|
||||
- meta-reference
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue