diff --git a/llama_stack/distribution/registry/datasets/__init__.py b/llama_stack/distribution/registry/datasets/__init__.py index f0636212a..68de3fa87 100644 --- a/llama_stack/distribution/registry/datasets/__init__.py +++ b/llama_stack/distribution/registry/datasets/__init__.py @@ -6,8 +6,13 @@ # TODO: make these import config based from llama_stack.apis.dataset import * # noqa: F403 +from ..registry import Registry from .dataset import CustomDataset, HuggingfaceDataset -from .dataset_registry import DatasetRegistry + + +class DatasetRegistry(Registry[BaseDataset]): + _REGISTRY: Dict[str, BaseDataset] = {} + DATASETS_REGISTRY = [ CustomDataset( diff --git a/llama_stack/distribution/registry/datasets/dataset_registry.py b/llama_stack/distribution/registry/datasets/dataset_registry.py deleted file mode 100644 index 8e9b22266..000000000 --- a/llama_stack/distribution/registry/datasets/dataset_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
-from typing import AbstractSet, Dict
-
-from llama_stack.apis.dataset import BaseDataset
-
-
-class DatasetRegistry:
-    _REGISTRY: Dict[str, BaseDataset] = {}
-
-    @staticmethod
-    def names() -> AbstractSet[str]:
-        return DatasetRegistry._REGISTRY.keys()
-
-    @staticmethod
-    def register(name: str, task: BaseDataset) -> None:
-        if name in DatasetRegistry._REGISTRY:
-            raise ValueError(f"Dataset {name} already exists.")
-        DatasetRegistry._REGISTRY[name] = task
-
-    @staticmethod
-    def get_dataset(name: str) -> BaseDataset:
-        if name not in DatasetRegistry._REGISTRY:
-            raise ValueError(f"Dataset {name} not found.")
-        return DatasetRegistry._REGISTRY[name]
-
-    @staticmethod
-    def reset() -> None:
-        DatasetRegistry._REGISTRY = {}
diff --git a/llama_stack/distribution/registry/generator_processors/__init__.py b/llama_stack/distribution/registry/generator_processors/__init__.py
new file mode 100644
index 000000000..bb9d5c182
--- /dev/null
+++ b/llama_stack/distribution/registry/generator_processors/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.evals import *  # noqa: F403
+
+from ..registry import Registry
+
+
+class GeneratorProcessorRegistry(Registry[BaseGeneratorProcessor]):
+    _REGISTRY: Dict[str, BaseGeneratorProcessor] = {}
diff --git a/llama_stack/distribution/registry/registry.py b/llama_stack/distribution/registry/registry.py
new file mode 100644
index 000000000..b4a5b626d
--- /dev/null
+++ b/llama_stack/distribution/registry/registry.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import AbstractSet, Dict, Generic, TypeVar
+
+TRegistry = TypeVar("TRegistry")
+
+
+class Registry(Generic[TRegistry]):
+    """Name-to-object registry that keeps a separate mapping per class.
+
+    Subclass it once per kind of object. Every method is a classmethod that
+    goes through ``cls._REGISTRY``, so a subclass reads and writes its own
+    mapping rather than the one defined on this base class.
+    """
+
+    _REGISTRY: Dict[str, TRegistry] = {}
+
+    def __init_subclass__(cls, **kwargs) -> None:
+        super().__init_subclass__(**kwargs)
+        # A subclass that does not declare its own mapping must not fall
+        # back to, and mutate, its parent's dict.
+        if "_REGISTRY" not in cls.__dict__:
+            cls._REGISTRY = {}
+
+    @classmethod
+    def names(cls) -> AbstractSet[str]:
+        """Return a view of the names registered on this class."""
+        return cls._REGISTRY.keys()
+
+    @classmethod
+    def register(cls, name: str, task: TRegistry) -> None:
+        """Store ``task`` under ``name``.
+
+        Raises:
+            ValueError: if ``name`` is already registered on this class.
+        """
+        if name in cls._REGISTRY:
+            raise ValueError(f"{name} already exists in {cls.__name__}.")
+        cls._REGISTRY[name] = task
+
+    @classmethod
+    def get(cls, name: str) -> TRegistry:
+        """Return the object registered under ``name``.
+
+        Raises:
+            ValueError: if ``name`` is not registered on this class.
+        """
+        if name not in cls._REGISTRY:
+            raise ValueError(f"{name} not found in {cls.__name__}.")
+        return cls._REGISTRY[name]
+
+    @classmethod
+    def reset(cls) -> None:
+        """Drop every entry registered on this class."""
+        cls._REGISTRY = {}
diff --git a/llama_stack/distribution/registry/scorers/__init__.py b/llama_stack/distribution/registry/scorers/__init__.py
index 76edd2ebd..3332b7052 100644
--- a/llama_stack/distribution/registry/scorers/__init__.py
+++ b/llama_stack/distribution/registry/scorers/__init__.py
@@ -4,3 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 # TODO: make these import config based
+from llama_stack.apis.evals import *  # noqa: F403
+
+from ..registry import Registry
+
+
+class ScorerRegistry(Registry[BaseScorer]):
+    _REGISTRY: Dict[str, BaseScorer] = {}
diff --git a/llama_stack/distribution/registry/scorers/scorer_registry.py b/llama_stack/distribution/registry/scorers/scorer_registry.py
deleted file mode 100644
index b6a382c53..000000000
--- a/llama_stack/distribution/registry/scorers/scorer_registry.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-from typing import AbstractSet, Dict - -from llama_stack.apis.evals import BaseScorer - - -class ScorerRegistry: - _REGISTRY: Dict[str, BaseScorer] = {} - - @staticmethod - def names() -> AbstractSet[str]: - return ScorerRegistry._REGISTRY.keys() - - @staticmethod - def register(name: str, scorer: BaseScorer) -> None: - if name in ScorerRegistry._REGISTRY: - raise ValueError(f"Task {name} already exists.") - ScorerRegistry._REGISTRY[name] = task - - @staticmethod - def get_scorer(name: str) -> BaseScorer: - if name not in ScorerRegistry._REGISTRY: - raise ValueError(f"Task {name} not found.") - return ScorerRegistry._REGISTRY[name] - - @staticmethod - def reset() -> None: - ScorerRegistry._REGISTRY = {} diff --git a/llama_stack/distribution/registry/tasks/__init__.py b/llama_stack/distribution/registry/tasks/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/distribution/registry/tasks/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/distribution/registry/tasks/task_registry.py b/llama_stack/distribution/registry/tasks/task_registry.py deleted file mode 100644 index df25686ba..000000000 --- a/llama_stack/distribution/registry/tasks/task_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
-from typing import AbstractSet, Dict - -from llama_stack.apis.evals import BaseTask - - -class TaskRegistry: - _REGISTRY: Dict[str, BaseTask] = {} - - @staticmethod - def names() -> AbstractSet[str]: - return TaskRegistry._REGISTRY.keys() - - @staticmethod - def register(name: str, task: BaseTask) -> None: - if name in TaskRegistry._REGISTRY: - raise ValueError(f"Task {name} already exists.") - TaskRegistry._REGISTRY[name] = task - - @staticmethod - def get_task(name: str) -> BaseTask: - if name not in TaskRegistry._REGISTRY: - raise ValueError(f"Task {name} not found.") - return TaskRegistry._REGISTRY[name] - - @staticmethod - def reset() -> None: - TaskRegistry._REGISTRY = {} diff --git a/llama_stack/providers/impls/meta_reference/evals/tasks/run_eval_task.py b/llama_stack/providers/impls/meta_reference/evals/tasks/run_eval_task.py index f3a66e18b..fde2efdb0 100644 --- a/llama_stack/providers/impls/meta_reference/evals/tasks/run_eval_task.py +++ b/llama_stack/providers/impls/meta_reference/evals/tasks/run_eval_task.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry +from llama_stack.distribution.registry.datasets import DatasetRegistry from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import * # noqa: F403 from llama_stack.providers.impls.meta_reference.evals.generator.inference_generator import ( InferenceGenerator, @@ -38,9 +38,8 @@ class RunEvalTask(BaseTask): ) -> EvalResult: print(f"Running eval task w/ {eval_task_config}") - dataset = DatasetRegistry.get_dataset( - eval_task_config.dataset_config.dataset_name - ) + print(DatasetRegistry.names()) + dataset = DatasetRegistry.get(eval_task_config.dataset_config.dataset_name) dataset.load(n_samples=eval_task_config.dataset_config.row_limit) print(f"Running on {len(dataset)} samples")