processor registry

This commit is contained in:
Xi Yan 2024-10-14 16:25:06 -07:00
parent 95fd53d292
commit a22c31b8a4
6 changed files with 33 additions and 10 deletions

View file

@@ -47,6 +47,9 @@ class MetaReferenceEvalsImpl(Evals):
dataset_name=dataset,
row_limit=3,
),
processor_config=EvaluateProcessorConfig(
processor_identifier="mmlu",
),
generation_config=EvaluateModelGenerationConfig(
model=model,
),

View file

@@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .mmlu_processor import MMLUProcessor # noqa: F401

View file

@@ -3,3 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .basic_scorers import * # noqa: F401 F403
from .aggregate_scorer import * # noqa: F401 F403

View file

@@ -4,15 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.registry.datasets import DatasetRegistry
from llama_stack.distribution.registry.generator_processors import (
GeneratorProcessorRegistry,
)
from llama_stack.distribution.registry.scorers import ScorerRegistry
from llama_stack.providers.impls.meta_reference.evals.scorer.aggregate_scorer import * # noqa: F403
from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import * # noqa: F403
from llama_stack.providers.impls.meta_reference.evals.generator.inference_generator import (
InferenceGenerator,
)
from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
MMLUProcessor,
)
from llama_stack.apis.evals import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
@@ -46,7 +48,10 @@ class RunEvalTask(BaseTask):
print(f"Running on {len(dataset)} samples")
# F1
processor = MMLUProcessor()
print(GeneratorProcessorRegistry.names())
processor = GeneratorProcessorRegistry.get(
eval_task_config.processor_config.processor_identifier
)()
preprocessed = processor.preprocess(dataset)
# Generation