mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
tasks registry
This commit is contained in:
parent
041634192a
commit
4764762dd4
9 changed files with 74 additions and 35 deletions
|
@ -43,20 +43,20 @@ async def run_main(host: str, port: int):
|
||||||
client = EvaluationClient(f"http://{host}:{port}")
|
client = EvaluationClient(f"http://{host}:{port}")
|
||||||
|
|
||||||
# CustomDataset
|
# CustomDataset
|
||||||
# response = await client.run_evals(
|
|
||||||
# "Llama3.1-8B-Instruct",
|
|
||||||
# "mmlu-simple-eval-en",
|
|
||||||
# "mmlu",
|
|
||||||
# )
|
|
||||||
# cprint(f"evaluate response={response}", "green")
|
|
||||||
|
|
||||||
# Eleuther Eval
|
|
||||||
response = await client.run_evals(
|
response = await client.run_evals(
|
||||||
"Llama3.1-8B-Instruct",
|
"Llama3.1-8B-Instruct",
|
||||||
"PLACEHOLDER_DATASET_NAME",
|
"mmlu-simple-eval-en",
|
||||||
"mmlu",
|
"mmlu",
|
||||||
)
|
)
|
||||||
cprint(response.metrics["metrics_table"], "red")
|
cprint(f"evaluate response={response}", "green")
|
||||||
|
|
||||||
|
# Eleuther Eval
|
||||||
|
# response = await client.run_evals(
|
||||||
|
# "Llama3.1-8B-Instruct",
|
||||||
|
# "PLACEHOLDER_DATASET_NAME",
|
||||||
|
# "mmlu",
|
||||||
|
# )
|
||||||
|
# cprint(response.metrics["metrics_table"], "red")
|
||||||
|
|
||||||
|
|
||||||
def main(host: str, port: int):
|
def main(host: str, port: int):
|
||||||
|
|
5
llama_stack/distribution/registry/__init__.py
Normal file
5
llama_stack/distribution/registry/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
13
llama_stack/distribution/registry/tasks/__init__.py
Normal file
13
llama_stack/distribution/registry/tasks/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
# TODO: make these imports config-based
|
||||||
|
from llama_stack.providers.impls.meta_reference.evals.tasks.mmlu_task import MMLUTask
|
||||||
|
from .task_registry import TaskRegistry
|
||||||
|
|
||||||
|
TaskRegistry.register(
|
||||||
|
"mmlu",
|
||||||
|
MMLUTask,
|
||||||
|
)
|
|
@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
|
||||||
|
|
||||||
class BaseTask(ABC):
|
class BaseTask(ABC):
|
||||||
"""
|
"""
|
||||||
|
A task represents a single evaluation benchmark, including its dataset, preprocessing, postprocessing and scoring methods.
|
||||||
Base class for all evaluation tasks. Each task needs to implement the following methods:
|
Base class for all evaluation tasks. Each task needs to implement the following methods:
|
||||||
- F1: preprocess_sample(self)
|
- F1: preprocess_sample(self)
|
||||||
- F2: postprocess_sample(self)
|
- F2: postprocess_sample(self)
|
32
llama_stack/distribution/registry/tasks/task_registry.py
Normal file
32
llama_stack/distribution/registry/tasks/task_registry.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from typing import AbstractSet, Dict, Type
|
||||||
|
|
||||||
|
from .task import BaseTask
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRegistry:
    """Process-wide registry mapping task names to evaluation task classes.

    Entries are task *classes*, not instances: ``get_task`` hands back the
    registered class and the caller instantiates it. All state lives in the
    class-level ``_REGISTRY`` dict, which is why every method is static.
    """

    # Maps a task name to the task class registered under it.
    _REGISTRY: Dict[str, Type[BaseTask]] = {}

    @staticmethod
    def names() -> AbstractSet[str]:
        """Return a live view of the currently registered task names."""
        return TaskRegistry._REGISTRY.keys()

    @staticmethod
    def register(name: str, task: Type[BaseTask]) -> None:
        """Register ``task`` under ``name``.

        Raises:
            ValueError: if ``name`` is already registered.
        """
        if name in TaskRegistry._REGISTRY:
            raise ValueError(f"Task {name} already exists.")
        TaskRegistry._REGISTRY[name] = task

    @staticmethod
    def get_task(name: str) -> Type[BaseTask]:
        """Return the task class registered under ``name``.

        Raises:
            ValueError: if no task is registered under ``name``.
        """
        try:
            return TaskRegistry._REGISTRY[name]
        except KeyError:
            # Keep the ValueError contract; hide the internal KeyError.
            raise ValueError(f"Task {name} not found.") from None

    @staticmethod
    def reset() -> None:
        """Remove every registered task."""
        # Clear in place rather than rebinding, so key views previously
        # returned by names() observe the reset instead of going stale.
        TaskRegistry._REGISTRY.clear()
|
|
@ -8,12 +8,15 @@ from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.apis.evals import * # noqa: F403
|
from llama_stack.apis.evals import * # noqa: F403
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
|
||||||
|
|
||||||
from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import (
|
from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import (
|
||||||
get_dataset,
|
get_dataset,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import (
|
|
||||||
get_task,
|
# from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import (
|
||||||
)
|
# get_task,
|
||||||
|
# )
|
||||||
|
|
||||||
from .config import MetaReferenceEvalsImplConfig
|
from .config import MetaReferenceEvalsImplConfig
|
||||||
|
|
||||||
|
@ -36,7 +39,8 @@ class MetaReferenceEvalsImpl(Evals):
|
||||||
) -> EvaluateResponse:
|
) -> EvaluateResponse:
|
||||||
cprint(f"model={model}, dataset={dataset}, task={task}", "red")
|
cprint(f"model={model}, dataset={dataset}, task={task}", "red")
|
||||||
dataset = get_dataset(dataset)
|
dataset = get_dataset(dataset)
|
||||||
task_impl = get_task(task, dataset)
|
task_impl = TaskRegistry.get_task(task)(dataset)
|
||||||
|
|
||||||
x1 = task_impl.preprocess()
|
x1 = task_impl.preprocess()
|
||||||
|
|
||||||
# TODO: replace w/ batch inference & async return eval job
|
# TODO: replace w/ batch inference & async return eval job
|
||||||
|
|
|
@ -5,8 +5,8 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from .task import BaseTask
|
|
||||||
from llama_stack.apis.evals import * # noqa: F403
|
from llama_stack.apis.evals import * # noqa: F403
|
||||||
|
from llama_stack.distribution.registry.tasks.task import BaseTask
|
||||||
|
|
||||||
QUERY_TEMPLATE_MULTICHOICE = """
|
QUERY_TEMPLATE_MULTICHOICE = """
|
||||||
Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
|
Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
|
||||||
|
|
|
@ -1,16 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
from .mmlu_task import MMLUTask
|
|
||||||
|
|
||||||
# TODO: make this into a config based registry
|
|
||||||
TASKS_REGISTRY = {
|
|
||||||
"mmlu": MMLUTask,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_task(task_id: str, dataset):
|
|
||||||
task_impl = TASKS_REGISTRY[task_id]
|
|
||||||
return task_impl(dataset)
|
|
|
@ -12,12 +12,12 @@ apis_to_serve:
|
||||||
- safety
|
- safety
|
||||||
- evals
|
- evals
|
||||||
api_providers:
|
api_providers:
|
||||||
evals:
|
|
||||||
provider_type: eleuther
|
|
||||||
config: {}
|
|
||||||
# evals:
|
# evals:
|
||||||
# provider_type: meta-reference
|
# provider_type: eleuther
|
||||||
# config: {}
|
# config: {}
|
||||||
|
evals:
|
||||||
|
provider_type: meta-reference
|
||||||
|
config: {}
|
||||||
inference:
|
inference:
|
||||||
providers:
|
providers:
|
||||||
- meta-reference
|
- meta-reference
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue