mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
registry refactor
This commit is contained in:
parent
a56ea48d71
commit
b87bdd0176
5 changed files with 52 additions and 22 deletions
|
@ -3,9 +3,11 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from .datasets import CustomDataset, HFDataset
|
||||
|
||||
# TODO: make this into a config based registry
|
||||
# TODO: make these import config based
|
||||
from .dataset import CustomDataset, HFDataset
|
||||
from .dataset_registry import DatasetRegistry
|
||||
|
||||
DATASETS_REGISTRY = {
|
||||
"mmlu-simple-eval-en": CustomDataset(
|
||||
name="mmlu_eval",
|
||||
|
@ -17,8 +19,5 @@ DATASETS_REGISTRY = {
|
|||
),
|
||||
}
|
||||
|
||||
|
||||
def get_dataset(dataset_id: str):
|
||||
dataset = DATASETS_REGISTRY[dataset_id]
|
||||
dataset.load()
|
||||
return dataset
|
||||
for k, v in DATASETS_REGISTRY.items():
|
||||
DatasetRegistry.register(k, v)
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import AbstractSet, Dict
|
||||
|
||||
from .dataset import BaseDataset
|
||||
|
||||
|
||||
class DatasetRegistry:
|
||||
_REGISTRY: Dict[str, BaseDataset] = {}
|
||||
|
||||
@staticmethod
|
||||
def names() -> AbstractSet[str]:
|
||||
return DatasetRegistry._REGISTRY.keys()
|
||||
|
||||
@staticmethod
|
||||
def register(name: str, task: BaseDataset) -> None:
|
||||
if name in DatasetRegistry._REGISTRY:
|
||||
raise ValueError(f"Dataset {name} already exists.")
|
||||
DatasetRegistry._REGISTRY[name] = task
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(name: str) -> BaseDataset:
|
||||
if name not in DatasetRegistry._REGISTRY:
|
||||
raise ValueError(f"Dataset {name} not found.")
|
||||
return DatasetRegistry._REGISTRY[name]
|
||||
|
||||
@staticmethod
|
||||
def reset() -> None:
|
||||
DatasetRegistry._REGISTRY = {}
|
|
@ -8,11 +8,9 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
from llama_stack.apis.evals import * # noqa: F403
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
|
||||
from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import (
|
||||
get_dataset,
|
||||
)
|
||||
from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
|
||||
|
||||
from .config import MetaReferenceEvalsImplConfig
|
||||
|
||||
|
@ -34,7 +32,8 @@ class MetaReferenceEvalsImpl(Evals):
|
|||
task: str,
|
||||
) -> EvaluateResponse:
|
||||
cprint(f"model={model}, dataset={dataset}, task={task}", "red")
|
||||
dataset = get_dataset(dataset)
|
||||
dataset = DatasetRegistry.get_dataset(dataset)
|
||||
dataset.load()
|
||||
task_impl = TaskRegistry.get_task(task)(dataset)
|
||||
|
||||
x1 = task_impl.preprocess()
|
||||
|
|
|
@ -41,20 +41,20 @@ routing_table:
|
|||
inference:
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.1-8B-Instruct
|
||||
model: Llama3.2-1B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
routing_key: Llama3.1-8B-Instruct
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
model: Llama-Guard-3-1B
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
routing_key: Llama-Guard-3-1B
|
||||
routing_key: Llama3.2-1B-Instruct
|
||||
# - provider_type: meta-reference
|
||||
# config:
|
||||
# model: Llama-Guard-3-1B
|
||||
# quantization: null
|
||||
# torch_seed: null
|
||||
# max_seq_len: 4096
|
||||
# max_batch_size: 1
|
||||
# routing_key: Llama-Guard-3-1B
|
||||
# - provider_type: remote::tgi
|
||||
# config:
|
||||
# url: http://127.0.0.1:5009
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue