mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
add data structure to tasks
This commit is contained in:
parent 9816c9aae6
commit ad18dc94ac
7 changed files with 100 additions and 168 deletions
@@ -6,6 +6,8 @@
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.apis.dataset import *  # noqa: F403
+
 from termcolor import cprint
 
 from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
 
@@ -42,30 +44,37 @@ class MetaReferenceEvalsImpl(Evals):
         dataset = DatasetRegistry.get_dataset(dataset)
         dataset.load()
 
-        task_impl = TaskRegistry.get_task(task)(dataset)
-        x1 = task_impl.preprocess()
+        task_impl = TaskRegistry.get_task(task)()
+        preprocessed = task_impl.preprocess(dataset)
 
         # TODO: replace w/ batch inference & async return eval job
         generation_outputs = []
         if eval_task_config is None:
-            eval_task_config = EvaluateTaskConfig(n_samples=len(x1))
-        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(x1):
-            eval_task_config.n_samples = len(x1)
+            eval_task_config = EvaluateTaskConfig(n_samples=len(preprocessed))
+        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(
+            preprocessed
+        ):
+            eval_task_config.n_samples = len(preprocessed)
 
         print(
             f"Eval generation start, generate on {eval_task_config.n_samples} samples"
         )
 
-        for msg in x1[: eval_task_config.n_samples]:
-            print("generation for msg: ", msg)
+        for sample in preprocessed[: eval_task_config.n_samples]:
+            print("generation: ", sample)
             response = await self.inference_api.chat_completion(
                 model=model,
-                messages=[msg],
+                messages=sample.preprocessed["messages"],
                 stream=False,
             )
-            generation_outputs.append(response.completion_message.content)
+            sample.prediction = PredictionSample(
+                completion_message=response.completion_message.content
+            )
+            generation_outputs.append(sample)
 
-        x2 = task_impl.postprocess(generation_outputs)
-        eval_results = task_impl.score(x2)
-        eval_response = task_impl.aggregate_results(eval_results)
-        return eval_response
+        postprocessed = task_impl.postprocess(generation_outputs)
+        eval_results = task_impl.score(postprocessed)
+        aggr_result = task_impl.aggregate_results(eval_results)
+        return EvaluateResponse(
+            eval_result=aggr_result,
+        )
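
Notes on the change: the generation loop no longer collects bare completion
strings; each preprocessed sample now carries its own prediction, and the
aggregated metrics come back wrapped in an EvaluateResponse. A minimal sketch
of the data structures this implies is below. PredictionSample and
EvaluateResponse appear in the diff, but their definitions here (and the
EvalSample name) are illustrative assumptions, not the actual llama-stack
types from llama_stack.apis.dataset / llama_stack.apis.evals.

from typing import Any, Dict, Optional

from pydantic import BaseModel


class PredictionSample(BaseModel):
    # Raw completion text captured from the inference response.
    completion_message: str


class EvalSample(BaseModel):
    # Hypothetical name for one preprocessed dataset row; the generation
    # loop above reads sample.preprocessed["messages"].
    preprocessed: Dict[str, Any]
    # Populated by the generation loop once chat_completion returns.
    prediction: Optional[PredictionSample] = None


class EvaluateResponse(BaseModel):
    # Wraps the aggregated metrics from task_impl.aggregate_results(...).
    eval_result: Dict[str, Any]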
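
The task interface itself also shifts: tasks are now constructed with no
arguments and receive the dataset through preprocess(), so a registered task
class can be reused across datasets. A sketch of the interface the call sites
imply, assuming a BaseTask abstract class (the method names come from the
diff; the class name and signatures are guesses for illustration):

from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseTask(ABC):
    # Hypothetical base class matching how MetaReferenceEvalsImpl drives a
    # task: preprocess -> generate -> postprocess -> score -> aggregate.

    @abstractmethod
    def preprocess(self, dataset) -> List["EvalSample"]:
        # Turn raw dataset rows into samples whose .preprocessed dict
        # holds the chat messages to send to the model.
        ...

    @abstractmethod
    def postprocess(self, generation_outputs: List["EvalSample"]) -> List["EvalSample"]:
        # Normalize / parse the completions stored on each .prediction.
        ...

    @abstractmethod
    def score(self, postprocessed: List["EvalSample"]) -> List[Dict[str, Any]]:
        # Produce a per-sample score record.
        ...

    @abstractmethod
    def aggregate_results(self, eval_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Reduce per-sample scores to aggregate metrics for EvaluateResponse.
        ...

Keeping dataset loading out of the task constructor is what lets
TaskRegistry.get_task(task)() instantiate the task with no arguments.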