registry refactor

Xi Yan 2024-10-08 15:44:02 -07:00
parent a56ea48d71
commit b87bdd0176
5 changed files with 52 additions and 22 deletions

View file

@@ -3,9 +3,11 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datasets import CustomDataset, HFDataset
# TODO: make this into a config based registry
# TODO: make these import config based
from .dataset import CustomDataset, HFDataset
from .dataset_registry import DatasetRegistry

DATASETS_REGISTRY = {
    "mmlu-simple-eval-en": CustomDataset(
        name="mmlu_eval",
@@ -17,8 +19,5 @@ DATASETS_REGISTRY = {
    ),
}

def get_dataset(dataset_id: str):
    dataset = DATASETS_REGISTRY[dataset_id]
    dataset.load()
    return dataset

for k, v in DATASETS_REGISTRY.items():
    DatasetRegistry.register(k, v)
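
Note (not part of the diff): the removed get_dataset() helper loaded the dataset before returning it, while DatasetRegistry.get_dataset() only performs the lookup, so callers now invoke load() themselves. A minimal sketch of the before/after, assuming the registry package is importable under the path used in the evals change further down:

# Sketch only; the import path is inferred from this commit and the dataset id
# comes from the registry above.
from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry

# Before: dataset = get_dataset("mmlu-simple-eval-en")   # lookup + load in one call
# After: lookup via the registry, then load explicitly.
dataset = DatasetRegistry.get_dataset("mmlu-simple-eval-en")
dataset.load()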

View file

@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AbstractSet, Dict

from .dataset import BaseDataset


class DatasetRegistry:
    _REGISTRY: Dict[str, BaseDataset] = {}

    @staticmethod
    def names() -> AbstractSet[str]:
        return DatasetRegistry._REGISTRY.keys()

    @staticmethod
    def register(name: str, task: BaseDataset) -> None:
        if name in DatasetRegistry._REGISTRY:
            raise ValueError(f"Dataset {name} already exists.")
        DatasetRegistry._REGISTRY[name] = task

    @staticmethod
    def get_dataset(name: str) -> BaseDataset:
        if name not in DatasetRegistry._REGISTRY:
            raise ValueError(f"Dataset {name} not found.")
        return DatasetRegistry._REGISTRY[name]

    @staticmethod
    def reset() -> None:
        DatasetRegistry._REGISTRY = {}
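
For reference, a hedged usage sketch of the registry API introduced above; the import path and the stand-in dataset class are assumptions, not part of this commit:

# Illustrative only: _StubDataset stands in for any BaseDataset instance; the
# registry itself just stores and returns whatever object is registered.
from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry


class _StubDataset:
    def load(self) -> None:
        pass


DatasetRegistry.register("my-eval-set", _StubDataset())  # raises ValueError on duplicate names
print(DatasetRegistry.names())                           # e.g. dict_keys(['mmlu-simple-eval-en', 'my-eval-set'])
ds = DatasetRegistry.get_dataset("my-eval-set")          # raises ValueError if the name is unknown
DatasetRegistry.reset()                                   # clears the registry, e.g. between tests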

View file

@@ -8,11 +8,9 @@ from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.evals import *  # noqa: F403
from termcolor import cprint

from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import (
    get_dataset,
)
from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry

from .config import MetaReferenceEvalsImplConfig
@@ -34,7 +32,8 @@ class MetaReferenceEvalsImpl(Evals):
        task: str,
    ) -> EvaluateResponse:
        cprint(f"model={model}, dataset={dataset}, task={task}", "red")
        dataset = get_dataset(dataset)
        dataset = DatasetRegistry.get_dataset(dataset)
        dataset.load()
        task_impl = TaskRegistry.get_task(task)(dataset)
        x1 = task_impl.preprocess()
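
Condensed, the resolution path inside evaluate() after this change looks as sketched below; the registry calls and the preprocess step come from the diff, while the concrete dataset and task ids are placeholders:

# Sketch of the lookup flow only; "mmlu" is a hypothetical task id.
from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry

dataset = DatasetRegistry.get_dataset("mmlu-simple-eval-en")
dataset.load()
task_impl = TaskRegistry.get_task("mmlu")(dataset)
x1 = task_impl.preprocess()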

View file

@@ -41,20 +41,20 @@ routing_table:
  inference:
    - provider_type: meta-reference
      config:
        model: Llama3.1-8B-Instruct
        model: Llama3.2-1B-Instruct
        quantization: null
        torch_seed: null
        max_seq_len: 4096
        max_batch_size: 1
      routing_key: Llama3.1-8B-Instruct
    - provider_type: meta-reference
      config:
        model: Llama-Guard-3-1B
        quantization: null
        torch_seed: null
        max_seq_len: 4096
        max_batch_size: 1
      routing_key: Llama-Guard-3-1B
      routing_key: Llama3.2-1B-Instruct
    # - provider_type: meta-reference
    #   config:
    #     model: Llama-Guard-3-1B
    #     quantization: null
    #     torch_seed: null
    #     max_seq_len: 4096
    #     max_batch_size: 1
    #   routing_key: Llama-Guard-3-1B
    # - provider_type: remote::tgi
    #   config:
    #     url: http://127.0.0.1:5009