From b87bdd0176df663bc4bdeae0ed3fbed32bf81d20 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 8 Oct 2024 15:44:02 -0700 Subject: [PATCH] registry refactor --- .../registry/datasets/__init__.py} | 13 ++++---- .../registry/datasets/dataset.py} | 0 .../registry/datasets/dataset_registry.py | 32 +++++++++++++++++++ .../impls/meta_reference/evals/evals.py | 9 +++--- tests/examples/local-run.yaml | 20 ++++++------ 5 files changed, 52 insertions(+), 22 deletions(-) rename llama_stack/{providers/impls/meta_reference/evals/datas/dataset_registry.py => distribution/registry/datasets/__init__.py} (71%) rename llama_stack/{providers/impls/meta_reference/evals/datas/datasets.py => distribution/registry/datasets/dataset.py} (100%) create mode 100644 llama_stack/distribution/registry/datasets/dataset_registry.py diff --git a/llama_stack/providers/impls/meta_reference/evals/datas/dataset_registry.py b/llama_stack/distribution/registry/datasets/__init__.py similarity index 71% rename from llama_stack/providers/impls/meta_reference/evals/datas/dataset_registry.py rename to llama_stack/distribution/registry/datasets/__init__.py index ad1c5f372..ffa000767 100644 --- a/llama_stack/providers/impls/meta_reference/evals/datas/dataset_registry.py +++ b/llama_stack/distribution/registry/datasets/__init__.py @@ -3,9 +3,11 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datasets import CustomDataset, HFDataset -# TODO: make this into a config based registry +# TODO: make these import config based +from .dataset import CustomDataset, HFDataset +from .dataset_registry import DatasetRegistry + DATASETS_REGISTRY = { "mmlu-simple-eval-en": CustomDataset( name="mmlu_eval", @@ -17,8 +19,5 @@ DATASETS_REGISTRY = { ), } - -def get_dataset(dataset_id: str): - dataset = DATASETS_REGISTRY[dataset_id] - dataset.load() - return dataset +for k, v in DATASETS_REGISTRY.items(): + DatasetRegistry.register(k, v) diff --git a/llama_stack/providers/impls/meta_reference/evals/datas/datasets.py b/llama_stack/distribution/registry/datasets/dataset.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/evals/datas/datasets.py rename to llama_stack/distribution/registry/datasets/dataset.py diff --git a/llama_stack/distribution/registry/datasets/dataset_registry.py b/llama_stack/distribution/registry/datasets/dataset_registry.py new file mode 100644 index 000000000..9ddaa8bb7 --- /dev/null +++ b/llama_stack/distribution/registry/datasets/dataset_registry.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import AbstractSet, Dict + +from .dataset import BaseDataset + + +class DatasetRegistry: + _REGISTRY: Dict[str, BaseDataset] = {} + + @staticmethod + def names() -> AbstractSet[str]: + return DatasetRegistry._REGISTRY.keys() + + @staticmethod + def register(name: str, task: BaseDataset) -> None: + if name in DatasetRegistry._REGISTRY: + raise ValueError(f"Dataset {name} already exists.") + DatasetRegistry._REGISTRY[name] = task + + @staticmethod + def get_dataset(name: str) -> BaseDataset: + if name not in DatasetRegistry._REGISTRY: + raise ValueError(f"Dataset {name} not found.") + return DatasetRegistry._REGISTRY[name] + + @staticmethod + def reset() -> None: + DatasetRegistry._REGISTRY = {} diff --git a/llama_stack/providers/impls/meta_reference/evals/evals.py b/llama_stack/providers/impls/meta_reference/evals/evals.py index 3ceb51aab..1ed7975eb 100644 --- a/llama_stack/providers/impls/meta_reference/evals/evals.py +++ b/llama_stack/providers/impls/meta_reference/evals/evals.py @@ -8,11 +8,9 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.evals import * # noqa: F403 from termcolor import cprint -from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry +from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry -from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import ( - get_dataset, -) +from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry from .config import MetaReferenceEvalsImplConfig @@ -34,7 +32,8 @@ class MetaReferenceEvalsImpl(Evals): task: str, ) -> EvaluateResponse: cprint(f"model={model}, dataset={dataset}, task={task}", "red") - dataset = get_dataset(dataset) + dataset = DatasetRegistry.get_dataset(dataset) + dataset.load() task_impl = TaskRegistry.get_task(task)(dataset) x1 = task_impl.preprocess() diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml index d24a9906c..78b1b32d7 100644 --- a/tests/examples/local-run.yaml +++ b/tests/examples/local-run.yaml @@ -41,20 +41,20 @@ routing_table: inference: - provider_type: meta-reference config: - model: Llama3.1-8B-Instruct + model: Llama3.2-1B-Instruct quantization: null torch_seed: null max_seq_len: 4096 max_batch_size: 1 - routing_key: Llama3.1-8B-Instruct - - provider_type: meta-reference - config: - model: Llama-Guard-3-1B - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - routing_key: Llama-Guard-3-1B + routing_key: Llama3.2-1B-Instruct + # - provider_type: meta-reference + # config: + # model: Llama-Guard-3-1B + # quantization: null + # torch_seed: null + # max_seq_len: 4096 + # max_batch_size: 1 + # routing_key: Llama-Guard-3-1B # - provider_type: remote::tgi # config: # url: http://127.0.0.1:5009