[Evals API][11/n] huggingface dataset provider + mmlu scoring fn (#392)

* wip

* scoring fn api

* eval api

* eval task

* evaluate api update

* pre commit

* unwrap context -> config

* config field doc

* typo

* naming fix

* separate benchmark / app eval

* api name

* rename

* wip tests

* wip

* datasetio test

* delete unused

* fixture

* scoring resolve

* fix scoring register

* scoring test pass

* score batch

* scoring fix

* fix eval

* test eval works

* huggingface provider

* datasetdef files

* mmlu scoring fn

* test wip

* remove type ignore

* api refactor

* add default task_eval_id for routing

* add eval_id for jobs

* remove type ignore

* huggingface provider

* wip huggingface register

* only keep 1 run_eval

* fix optional

* register task required

* register task required

* delete old tests

* fix

* mmlu loose

* refactor

* msg

* fix tests

* move benchmark task def to file

* msg

* gen openapi

* openapi gen

* move dataset to hf llamastack repo

* remove todo

* refactor

* add register model to unit test

* rename

* register to client

* delete preregistered dataset/eval task

* comments

* huggingface -> remote adapter

* openapi gen
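
For orientation, a minimal sketch of how these pieces compose once the dataset, scoring fn, and eval task are registered. Treat it as an illustration rather than this PR's exact surface: the class names (BenchmarkEvalTaskConfig, ModelCandidate), import paths, model name, and the task id are assumptions inferred from the commit trail above, not copied from the diff.

    # Hedged sketch: kick off a benchmark eval against the HF-backed MMLU
    # dataset. All identifiers below are placeholders.
    from llama_models.llama3.api.datatypes import SamplingParams
    from llama_stack.apis.eval.eval import BenchmarkEvalTaskConfig, ModelCandidate

    async def run_mmlu(eval_impl):
        job = await eval_impl.run_eval(
            task_id="meta-reference-mmlu",  # placeholder EvalTaskDef identifier
            task_config=BenchmarkEvalTaskConfig(
                eval_candidate=ModelCandidate(
                    model="Llama3.1-8B-Instruct",  # any registered model
                    # the eval loop asserts max_tokens is set (see diff below)
                    sampling_params=SamplingParams(max_tokens=512),
                ),
                # New in this PR: cap rows fetched from the dataset;
                # None keeps the old fetch-everything behavior.
                num_examples=5,
            ),
        )
        return job

num_examples maps directly onto the rows_in_page change in the diff below.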
Xi Yan authored on 2024-11-11 14:49:50 -05:00, committed by GitHub
parent b78ee3a0a5
commit 2b7d70ba86
20 changed files with 1607 additions and 718 deletions

@@ -9,6 +9,8 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from .....apis.common.job_types import Job
 from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.common.type_system import *  # noqa: F403
+from tqdm import tqdm
+
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTaskDef
@@ -47,7 +49,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         self.eval_tasks = {}
 
-    async def initialize(self) -> None: ...
+    async def initialize(self) -> None:
+        pass
 
     async def shutdown(self) -> None: ...
@@ -93,7 +96,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
-            rows_in_page=-1,
+            rows_in_page=(
+                -1 if task_config.num_examples is None else task_config.num_examples
+            ),
         )
         res = await self.evaluate_rows(
             task_id=task_id,
@@ -125,7 +130,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         ), "SamplingParams.max_tokens must be provided"
 
         generations = []
-        for x in input_rows:
+        for x in tqdm(input_rows):
             if ColumnName.completion_input.value in x:
                 input_content = eval(str(x[ColumnName.completion_input.value]))
                 response = await self.inference_api.completion(
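
Two behavior notes on this hunk and the one above it: rows_in_page=-1 is how this file previously fetched the entire dataset, so threading task_config.num_examples through lets a config evaluate a small slice (handy for smoke tests) without pulling the full benchmark; and wrapping the generation loop in tqdm adds a progress bar, which full-dataset runs make worth having.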