mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-01 20:18:50 +00:00
[Evals API][11/n] huggingface dataset provider + mmlu scoring fn (#392)
* wip * scoring fn api * eval api * eval task * evaluate api update * pre commit * unwrap context -> config * config field doc * typo * naming fix * separate benchmark / app eval * api name * rename * wip tests * wip * datasetio test * delete unused * fixture * scoring resolve * fix scoring register * scoring test pass * score batch * scoring fix * fix eval * test eval works * huggingface provider * datasetdef files * mmlu scoring fn * test wip * remove type ignore * api refactor * add default task_eval_id for routing * add eval_id for jobs * remove type ignore * huggingface provider * wip huggingface register * only keep 1 run_eval * fix optional * register task required * register task required * delete old tests * fix * mmlu loose * refactor * msg * fix tests * move benchmark task def to file * msg * gen openapi * openapi gen * move dataset to hf llamastack repo * remove todo * refactor * add register model to unit test * rename * register to client * delete preregistered dataset/eval task * comments * huggingface -> remote adapter * openapi gen
This commit is contained in:
parent
b78ee3a0a5
commit
2b7d70ba86
20 changed files with 1607 additions and 718 deletions
|
@ -3,20 +3,17 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import io
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import unquote
|
||||
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import parse_data_url
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
|
||||
from .config import MetaReferenceDatasetIOConfig
|
||||
|
||||
|
@ -73,31 +70,9 @@ class PandasDataframeDataset(BaseDataset):
|
|||
if self.df is not None:
|
||||
return
|
||||
|
||||
# TODO: more robust support w/ data url
|
||||
if self.dataset_def.url.uri.endswith(".csv"):
|
||||
df = pandas.read_csv(self.dataset_def.url.uri)
|
||||
elif self.dataset_def.url.uri.endswith(".xlsx"):
|
||||
df = pandas.read_excel(self.dataset_def.url.uri)
|
||||
elif self.dataset_def.url.uri.startswith("data:"):
|
||||
parts = parse_data_url(self.dataset_def.url.uri)
|
||||
data = parts["data"]
|
||||
if parts["is_base64"]:
|
||||
data = base64.b64decode(data)
|
||||
else:
|
||||
data = unquote(data)
|
||||
encoding = parts["encoding"] or "utf-8"
|
||||
data = data.encode(encoding)
|
||||
|
||||
mime_type = parts["mimetype"]
|
||||
mime_category = mime_type.split("/")[0]
|
||||
data_bytes = io.BytesIO(data)
|
||||
|
||||
if mime_category == "text":
|
||||
df = pandas.read_csv(data_bytes)
|
||||
else:
|
||||
df = pandas.read_excel(data_bytes)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
|
||||
df = get_dataframe_from_url(self.dataset_def.url)
|
||||
if df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
|
||||
self.df = self._validate_dataset_schema(df)
|
||||
|
||||
|
|
|
@ -9,6 +9,8 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from .....apis.common.job_types import Job
|
||||
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval_tasks import EvalTaskDef
|
||||
|
@ -47,7 +49,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
|
||||
self.eval_tasks = {}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
|
@ -93,7 +96,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
rows_in_page=(
|
||||
-1 if task_config.num_examples is None else task_config.num_examples
|
||||
),
|
||||
)
|
||||
res = await self.evaluate_rows(
|
||||
task_id=task_id,
|
||||
|
@ -125,7 +130,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
), "SamplingParams.max_tokens must be provided"
|
||||
|
||||
generations = []
|
||||
for x in input_rows:
|
||||
for x in tqdm(input_rows):
|
||||
if ColumnName.completion_input.value in x:
|
||||
input_content = eval(str(x[ColumnName.completion_input.value]))
|
||||
response = await self.inference_api.completion(
|
||||
|
|
|
@ -13,21 +13,14 @@ from llama_stack.apis.datasetio import * # noqa: F403
|
|||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
|
||||
EqualityScoringFn,
|
||||
)
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
|
||||
LlmAsJudgeScoringFn,
|
||||
)
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
|
||||
SubsetOfScoringFn,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceScoringConfig
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
|
||||
|
||||
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
|
||||
|
||||
|
|
|
@ -11,6 +11,5 @@ from llama_stack.apis.scoring_functions import ScoringFnDef
|
|||
equality = ScoringFnDef(
|
||||
identifier="meta-reference::equality",
|
||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||
parameters=[],
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
|
|
@ -26,7 +26,6 @@ Total rating:
|
|||
llm_as_judge_8b_correctness = ScoringFnDef(
|
||||
identifier="meta-reference::llm_as_judge_8b_correctness",
|
||||
description="Llm As Judge Scoring Function",
|
||||
parameters=[],
|
||||
return_type=NumberType(),
|
||||
params=LLMAsJudgeScoringFnParams(
|
||||
prompt_template=JUDGE_PROMPT,
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
|
||||
MULTILINGUAL_ANSWER_REGEXES = [
|
||||
r"Answer\s*:",
|
||||
r"Answer\s*:", # Korean invisible character
|
||||
r"উত্তর\s*:",
|
||||
r"उत्तर\s*:",
|
||||
r"উত্তরঃ",
|
||||
r"উত্তর\s*:",
|
||||
r"Antwort\s*:",
|
||||
r"답변\s*:",
|
||||
r"정답\s*:",
|
||||
r"답\s*:",
|
||||
r"答案\s*:",
|
||||
r"答案\s*:",
|
||||
r"答\s*:",
|
||||
r"答\s*:",
|
||||
r"答复\s*:",
|
||||
r"答曰\s*:",
|
||||
r"الإجابة:",
|
||||
r"الجواب:",
|
||||
r"إجابة:",
|
||||
r"الإجابة النهائية:",
|
||||
r"الإجابة الصحيحة:",
|
||||
r"الإجابة الصحيحة هي:",
|
||||
r"الإجابة هي:",
|
||||
r"Respuesta\s*:",
|
||||
r"Risposta\s*:",
|
||||
r"答え\s*:",
|
||||
r"答え\s*:",
|
||||
r"回答\s*:",
|
||||
r"回答\s*:",
|
||||
r"解答\s*:",
|
||||
r"Jawaban\s*:",
|
||||
r"Réponse\s*:",
|
||||
r"Resposta\s*:",
|
||||
r"Jibu\s*:",
|
||||
r"Idahun\s*:",
|
||||
r"Ìdáhùn\s*:",
|
||||
r"Idáhùn\s*:",
|
||||
r"Àmọ̀nà\s*:",
|
||||
r"Àdáhùn\s*:",
|
||||
r"Ànúgọ\s*:",
|
||||
r"Àṣàyàn\s*:",
|
||||
]
|
||||
|
||||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
|
||||
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
|
||||
)
|
||||
|
||||
regex_parser_multiple_choice_answer = ScoringFnDef(
|
||||
identifier="meta-reference::regex_parser_multiple_choice_answer",
|
||||
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
|
||||
return_type=NumberType(),
|
||||
params=RegexParserScoringFnParams(
|
||||
parsing_regexes=[
|
||||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
|
||||
for x in MULTILINGUAL_ANSWER_REGEXES
|
||||
],
|
||||
),
|
||||
)
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import re
|
||||
|
||||
from .base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from .common import aggregate_accuracy
|
||||
|
||||
from .fn_defs.regex_parser_multiple_choice_answer import (
|
||||
regex_parser_multiple_choice_answer,
|
||||
)
|
||||
|
||||
|
||||
class RegexParserScoringFn(BaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert (
|
||||
scoring_fn_identifier is not None
|
||||
), "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert (
|
||||
fn_def.params is not None
|
||||
and fn_def.params.type == ScoringConfigType.regex_parser.value
|
||||
), f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
# parse answer according to regex
|
||||
parsed_answer = None
|
||||
for regex in fn_def.params.parsing_regexes:
|
||||
match = re.search(regex, generated_answer)
|
||||
if match:
|
||||
parsed_answer = match.group(1)
|
||||
break
|
||||
|
||||
score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_accuracy(scoring_results)
|
Loading…
Add table
Add a link
Reference in a new issue