mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
registry refactor
This commit is contained in:
parent
a56ea48d71
commit
b87bdd0176
5 changed files with 52 additions and 22 deletions
|
@ -3,9 +3,11 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from .datasets import CustomDataset, HFDataset
|
|
||||||
|
|
||||||
# TODO: make this into a config based registry
|
# TODO: make these import config based
|
||||||
|
from .dataset import CustomDataset, HFDataset
|
||||||
|
from .dataset_registry import DatasetRegistry
|
||||||
|
|
||||||
DATASETS_REGISTRY = {
|
DATASETS_REGISTRY = {
|
||||||
"mmlu-simple-eval-en": CustomDataset(
|
"mmlu-simple-eval-en": CustomDataset(
|
||||||
name="mmlu_eval",
|
name="mmlu_eval",
|
||||||
|
@ -17,8 +19,5 @@ DATASETS_REGISTRY = {
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for k, v in DATASETS_REGISTRY.items():
|
||||||
def get_dataset(dataset_id: str):
|
DatasetRegistry.register(k, v)
|
||||||
dataset = DATASETS_REGISTRY[dataset_id]
|
|
||||||
dataset.load()
|
|
||||||
return dataset
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from typing import AbstractSet, Dict
|
||||||
|
|
||||||
|
from .dataset import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetRegistry:
|
||||||
|
_REGISTRY: Dict[str, BaseDataset] = {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def names() -> AbstractSet[str]:
|
||||||
|
return DatasetRegistry._REGISTRY.keys()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def register(name: str, task: BaseDataset) -> None:
|
||||||
|
if name in DatasetRegistry._REGISTRY:
|
||||||
|
raise ValueError(f"Dataset {name} already exists.")
|
||||||
|
DatasetRegistry._REGISTRY[name] = task
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_dataset(name: str) -> BaseDataset:
|
||||||
|
if name not in DatasetRegistry._REGISTRY:
|
||||||
|
raise ValueError(f"Dataset {name} not found.")
|
||||||
|
return DatasetRegistry._REGISTRY[name]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reset() -> None:
|
||||||
|
DatasetRegistry._REGISTRY = {}
|
|
@ -8,11 +8,9 @@ from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.apis.evals import * # noqa: F403
|
from llama_stack.apis.evals import * # noqa: F403
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
|
from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
|
||||||
|
|
||||||
from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import (
|
from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
|
||||||
get_dataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .config import MetaReferenceEvalsImplConfig
|
from .config import MetaReferenceEvalsImplConfig
|
||||||
|
|
||||||
|
@ -34,7 +32,8 @@ class MetaReferenceEvalsImpl(Evals):
|
||||||
task: str,
|
task: str,
|
||||||
) -> EvaluateResponse:
|
) -> EvaluateResponse:
|
||||||
cprint(f"model={model}, dataset={dataset}, task={task}", "red")
|
cprint(f"model={model}, dataset={dataset}, task={task}", "red")
|
||||||
dataset = get_dataset(dataset)
|
dataset = DatasetRegistry.get_dataset(dataset)
|
||||||
|
dataset.load()
|
||||||
task_impl = TaskRegistry.get_task(task)(dataset)
|
task_impl = TaskRegistry.get_task(task)(dataset)
|
||||||
|
|
||||||
x1 = task_impl.preprocess()
|
x1 = task_impl.preprocess()
|
||||||
|
|
|
@ -41,20 +41,20 @@ routing_table:
|
||||||
inference:
|
inference:
|
||||||
- provider_type: meta-reference
|
- provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
model: Llama3.1-8B-Instruct
|
model: Llama3.2-1B-Instruct
|
||||||
quantization: null
|
quantization: null
|
||||||
torch_seed: null
|
torch_seed: null
|
||||||
max_seq_len: 4096
|
max_seq_len: 4096
|
||||||
max_batch_size: 1
|
max_batch_size: 1
|
||||||
routing_key: Llama3.1-8B-Instruct
|
routing_key: Llama3.2-1B-Instruct
|
||||||
- provider_type: meta-reference
|
# - provider_type: meta-reference
|
||||||
config:
|
# config:
|
||||||
model: Llama-Guard-3-1B
|
# model: Llama-Guard-3-1B
|
||||||
quantization: null
|
# quantization: null
|
||||||
torch_seed: null
|
# torch_seed: null
|
||||||
max_seq_len: 4096
|
# max_seq_len: 4096
|
||||||
max_batch_size: 1
|
# max_batch_size: 1
|
||||||
routing_key: Llama-Guard-3-1B
|
# routing_key: Llama-Guard-3-1B
|
||||||
# - provider_type: remote::tgi
|
# - provider_type: remote::tgi
|
||||||
# config:
|
# config:
|
||||||
# url: http://127.0.0.1:5009
|
# url: http://127.0.0.1:5009
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue