mirror of https://github.com/meta-llama/llama-stack.git

commit 78cb88c3c4 (parent 8890de7322)
RunEvalTask / InferenceGenerator

5 changed files with 111 additions and 84 deletions
@@ -171,7 +171,7 @@ class BaseGeneratorProcessor(
         raise NotImplementedError()


-class BaseGenerator(ABC, Generic[TGenerationResponseSample]):
+class BaseGenerator(ABC, Generic[TPreprocessedSample, TGenerationResponseSample]):
     """
     Base class for all generators. Each generator needs to implement the following methods:
     - generate(self, preprocessed_dataset)
@@ -184,7 +184,7 @@ class BaseGenerator(ABC, Generic[TGenerationResponseSample]):
         return self.__class__.__name__

     @abstractmethod
-    def generate(
+    async def generate(
         self, preprocessed_dataset: List[TPreprocessedSample]
     ) -> List[TGenerationResponseSample]:
         raise NotImplementedError()
@@ -231,7 +231,7 @@ class BaseTask(ABC):
         self.scorer = scorer

     @abstractmethod
-    def run(self, *args, **kwargs) -> EvalResult:
+    async def run(self, *args, **kwargs) -> EvalResult:
         raise NotImplementedError()

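For context on what these signature changes enable, here is a minimal sketch (not from the commit) of a subclass satisfying the now-async, two-type-parameter BaseGenerator contract. EchoGenerator and the str sample types are hypothetical stand-ins for the provider's real sample dataclasses:

# Hypothetical sketch only: EchoGenerator and the str sample types are
# illustrative stand-ins, not part of this commit.
import asyncio
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar

TPreprocessedSample = TypeVar("TPreprocessedSample")
TGenerationResponseSample = TypeVar("TGenerationResponseSample")


class BaseGenerator(ABC, Generic[TPreprocessedSample, TGenerationResponseSample]):
    # Mirrors the abstract contract after this commit: generate() is async
    # and maps preprocessed samples to generation-response samples.
    @abstractmethod
    async def generate(
        self, preprocessed_dataset: List[TPreprocessedSample]
    ) -> List[TGenerationResponseSample]:
        raise NotImplementedError()


class EchoGenerator(BaseGenerator[str, str]):
    # Trivial implementation: returns each sample unchanged, standing in
    # for a real model call.
    async def generate(self, preprocessed_dataset: List[str]) -> List[str]:
        return list(preprocessed_dataset)


print(asyncio.run(EchoGenerator().generate(["q1", "q2"])))  # ['q1', 'q2']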
@@ -7,25 +7,14 @@ import json

 from termcolor import cprint

-from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import (
-    AggregateScorer,
-)
-
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403
 from llama_stack.apis.dataset import *  # noqa: F403

-from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
-from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
-    MMLUProcessor,
-)
-
+from .config import MetaReferenceEvalsImplConfig
 # from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
-# from .tasks.run_eval_task import RunEvalTask
-from .scorer.basic_scorers import *  # noqa: F403
-
-
-from .config import MetaReferenceEvalsImplConfig
-
-
+from .tasks.run_eval_task import RunEvalTask

 class MetaReferenceEvalsImpl(Evals):
@@ -70,58 +59,8 @@ class MetaReferenceEvalsImpl(Evals):
             ),
         )

-        # TODO: wrap inside task
-        # run_task = RunEvalTask(
-        #     eval_task_config=eval_task_config,
-        # )
-        # eval_result = run_task.run()
-
-        dataset = DatasetRegistry.get_dataset(
-            eval_task_config.dataset_config.dataset_name
-        )
-        dataset.load(n_samples=eval_task_config.dataset_config.row_limit)
-        print(f"Running on {len(dataset)} samples")
-
-        # F1
-        processor = MMLUProcessor()
-        preprocessed = processor.preprocess(dataset)
-
-        # Generation
-        # TODO: wrap inside BaseGenerator
-        generation_outputs = []
-        for sample in preprocessed:
-            print("generation: ", sample)
-            response = await self.inference_api.chat_completion(
-                model=model,
-                messages=sample.generation_input.messages,
-                stream=False,
-            )
-            cprint(f"response: {response}", "cyan")
-
-            generation_outputs.append(
-                GenerationResponseSample(
-                    generation_output=GenerationOutput(
-                        completion_message=response.completion_message.content
-                    )
-                )
-            )
-        cprint(generation_outputs, "green")
-
-        # F2
-        postprocessed = processor.postprocess(generation_outputs, dataset)
-        cprint(postprocessed, "blue")
-
-        # F3 - scorer
-        scorer = AggregateScorer(
-            scorers=[
-                AccuracyScorer(),
-                RandomScorer(),
-            ]
-        )
-
-        scorer_results = scorer.score(postprocessed)
-        cprint(scorer_results, "magenta")
-        eval_result = scorer.aggregate_results(scorer_results)
+        run_task = RunEvalTask()
+        eval_result = await run_task.run(eval_task_config, self.inference_api)

         return EvaluateResponse(
             eval_result=eval_result,
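The net effect of this hunk: the inlined eval pipeline moves out of MetaReferenceEvalsImpl, which now only constructs a RunEvalTask and awaits its run(). A minimal sketch of that delegation shape follows; the Stub* names and simplified types are illustrative, not the provider's real classes:

# Hypothetical sketch of the delegation pattern; StubTask/StubConfig are
# illustrative, not llama-stack classes.
import asyncio
from dataclasses import dataclass


@dataclass
class StubConfig:
    dataset_name: str


class StubTask:
    # Stands in for RunEvalTask: all pipeline work lives behind one
    # awaitable entry point.
    async def run(self, config: StubConfig) -> dict:
        return {"dataset": config.dataset_name, "score": 1.0}


class StubEvalsImpl:
    # Stands in for MetaReferenceEvalsImpl: construct the task, await it,
    # wrap the result for the API response.
    async def run_eval(self, config: StubConfig) -> dict:
        task = StubTask()
        eval_result = await task.run(config)
        return {"eval_result": eval_result}


print(asyncio.run(StubEvalsImpl().run_eval(StubConfig("mmlu"))))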
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from termcolor import cprint
+
+from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+
+
+class InferenceGenerator(BaseGenerator[PreprocessedSample, GenerationResponseSample]):
+    """
+    InferenceGenerator for LlamaStack
+    """
+
+    def __init__(
+        self,
+        model,
+        inference_api,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.model = model
+        self.inference_api = inference_api
+
+    async def generate(
+        self, preprocessed_dataset: List[PreprocessedSample]
+    ) -> List[GenerationResponseSample]:
+        generation_outputs = []
+        for sample in preprocessed_dataset:
+            print("generation: ", sample)
+            response = await self.inference_api.chat_completion(
+                model=self.model,
+                messages=sample.generation_input.messages,
+                stream=False,
+            )
+            cprint(f"response: {response}", "cyan")
+
+            generation_outputs.append(
+                GenerationResponseSample(
+                    generation_output=GenerationOutput(
+                        completion_message=response.completion_message.content
+                    )
+                )
+            )
+        return generation_outputs
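Because generate() only awaits inference_api.chat_completion(...), the new class can be exercised with any object exposing that coroutine. A hedged usage sketch with stubs follows; the Stub* names and the simplified sample/response shapes are assumptions, not llama-stack's real dataclasses:

# Hypothetical harness: stub out just enough of the inference API and the
# sample/response types to drive InferenceGenerator-style code.
import asyncio
from dataclasses import dataclass
from typing import List


@dataclass
class StubCompletionMessage:
    content: str


@dataclass
class StubResponse:
    completion_message: StubCompletionMessage


@dataclass
class StubGenerationInput:
    messages: List[str]


@dataclass
class StubSample:
    generation_input: StubGenerationInput


class StubInferenceApi:
    # Mimics the chat_completion(model=..., messages=..., stream=False)
    # coroutine the generator awaits.
    async def chat_completion(self, model, messages, stream=False):
        return StubResponse(StubCompletionMessage(f"echo({model}): {messages[0]}"))


async def main():
    api = StubInferenceApi()
    samples = [StubSample(StubGenerationInput(["What is 2+2?"]))]
    outputs = []
    for sample in samples:  # same loop shape as InferenceGenerator.generate
        response = await api.chat_completion(
            model="stub-model", messages=sample.generation_input.messages, stream=False
        )
        outputs.append(response.completion_message.content)
    print(outputs)  # ['echo(stub-model): What is 2+2?']


asyncio.run(main())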
@@ -4,8 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
+from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.evals.generator.inference_generator import (
+    InferenceGenerator,
+)
+from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
+    MMLUProcessor,
+)

 from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from termcolor import cprint


 class RunEvalTask(BaseTask):
@@ -15,25 +24,51 @@ class RunEvalTask(BaseTask):

     def __init__(
         self,
-        eval_task_config,
-        generator_processor: Optional[BaseGeneratorProcessor] = None,
-        generator: Optional[BaseGenerator] = None,
-        scorer: Optional[BaseScorer] = None,
         *args,
         **kwargs,
     ) -> None:
-        super().__init__(
-            generator_processor=generator_processor,
-            generator=generator,
-            scorer=scorer,
-            *args,
-            **kwargs,
-        )
-        self.eval_task_config = eval_task_config
-        self.dataset = DatasetRegistry.get_dataset(
-            eval_task_config.dataset_config.dataset_name
-        )
-
-    def run(self, *args, **kwargs) -> EvalResult:
-        print(f"Running eval task on {self.dataset}")
-        return EvalResult()
+        super().__init__(*args, **kwargs)
+
+    async def run(
+        self,
+        eval_task_config: EvaluateTaskConfig,
+        inference_api: Inference,
+        *args,
+        **kwargs,
+    ) -> EvalResult:
+        print(f"Running eval task w/ {eval_task_config}")
+
+        dataset = DatasetRegistry.get_dataset(
+            eval_task_config.dataset_config.dataset_name
+        )
+        dataset.load(n_samples=eval_task_config.dataset_config.row_limit)
+        print(f"Running on {len(dataset)} samples")
+
+        # F1
+        processor = MMLUProcessor()
+        preprocessed = processor.preprocess(dataset)
+
+        # Generation
+        generator = InferenceGenerator(
+            model=eval_task_config.generation_config.model,
+            inference_api=inference_api,
+        )
+        generation_outputs = await generator.generate(preprocessed)
+
+        # F2
+        postprocessed = processor.postprocess(generation_outputs, dataset)
+        cprint(postprocessed, "blue")
+
+        # F3 - scorer
+        scorer = AggregateScorer(
+            scorers=[
+                AccuracyScorer(),
+                RandomScorer(),
+            ]
+        )
+
+        scorer_results = scorer.score(postprocessed)
+        cprint(scorer_results, "magenta")
+        eval_result = scorer.aggregate_results(scorer_results)
+
+        return eval_result
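run() now owns the whole pipeline: preprocess (F1), generate, postprocess (F2), then score and aggregate (F3). The scorer composition at F3 is worth spelling out. Below is a minimal sketch of the aggregate pattern; the Stub* names are hypothetical, and the score()/aggregate_results() signatures are assumptions inferred from how they are called above:

# Hypothetical sketch of the AggregateScorer pattern: fan each sample out
# to several scorers, then average per scorer. Names and signatures are
# assumed from the call sites above, not copied from llama-stack.
from typing import Dict, List


class StubAccuracyScorer:
    name = "accuracy"

    def score(self, samples: List[dict]) -> List[float]:
        # 1.0 for an exact match between prediction and target, else 0.0.
        return [1.0 if s["prediction"] == s["target"] else 0.0 for s in samples]


class StubAggregateScorer:
    def __init__(self, scorers) -> None:
        self.scorers = scorers

    def score(self, samples: List[dict]) -> Dict[str, List[float]]:
        # Per-sample scores, keyed by scorer name.
        return {s.name: s.score(samples) for s in self.scorers}

    def aggregate_results(self, results: Dict[str, List[float]]) -> Dict[str, float]:
        # Collapse per-sample scores into one mean per scorer.
        return {name: sum(vals) / len(vals) for name, vals in results.items()}


samples = [{"prediction": "A", "target": "A"}, {"prediction": "B", "target": "C"}]
scorer = StubAggregateScorer([StubAccuracyScorer()])
print(scorer.aggregate_results(scorer.score(samples)))  # {'accuracy': 0.5}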