Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
[Evals API][11/n] huggingface dataset provider + mmlu scoring fn (#392)
* wip
* scoring fn api
* eval api
* eval task
* evaluate api update
* pre commit
* unwrap context -> config
* config field doc
* typo
* naming fix
* separate benchmark / app eval
* api name
* rename
* wip tests
* wip
* datasetio test
* delete unused
* fixture
* scoring resolve
* fix scoring register
* scoring test pass
* score batch
* scoring fix
* fix eval
* test eval works
* huggingface provider
* datasetdef files
* mmlu scoring fn
* test wip
* remove type ignore
* api refactor
* add default task_eval_id for routing
* add eval_id for jobs
* remove type ignore
* huggingface provider
* wip huggingface register
* only keep 1 run_eval
* fix optional
* register task required
* register task required
* delete old tests
* fix
* mmlu loose
* refactor
* msg
* fix tests
* move benchmark task def to file
* msg
* gen openapi
* openapi gen
* move dataset to hf llamastack repo
* remove todo
* refactor
* add register model to unit test
* rename
* register to client
* delete preregistered dataset/eval task
* comments
* huggingface -> remote adapter
* openapi gen
Parent: b78ee3a0a5
Commit: 2b7d70ba86
20 changed files with 1607 additions and 718 deletions
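In short, this commit lets a Hugging Face hub dataset be registered through a new remote datasetio adapter, tied to an eval task that uses the new regex-parser MMLU scoring function, and driven through the Eval API. A condensed sketch of that flow, lifted from the new test_eval_run_benchmark_eval test further down; the datasets_impl, eval_tasks_impl, and eval_impl handles are assumed to come from an already-configured stack (in the test they come from the eval_stack fixture), and the candidate model is assumed to be registered:

from llama_models.llama3.api import URL, SamplingParams
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider
from llama_stack.apis.eval.eval import (
    BenchmarkEvalTaskConfig,
    EvalTaskDefWithProvider,
    ModelCandidate,
)


async def run_mmlu_benchmark(datasets_impl, eval_tasks_impl, eval_impl):
    # The huggingface datasetio adapter loads this via hf_datasets.load_dataset(**metadata).
    await datasets_impl.register_dataset(
        DatasetDefWithProvider(
            identifier="mmlu",
            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
            dataset_schema={
                "input_query": StringType(),
                "expected_answer": StringType(),
                "chat_completion_input": ChatCompletionInputType(),
            },
            metadata={
                "path": "llamastack/evals",
                "name": "evals__mmlu__details",
                "split": "train",
            },
            provider_id="",
        )
    )
    # The eval task pairs the dataset with the multiple-choice regex scoring function.
    await eval_tasks_impl.register_eval_task(
        EvalTaskDefWithProvider(
            identifier="meta-reference-mmlu",
            dataset_id="mmlu",
            scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
            provider_id="",
        )
    )
    # Kick off a small benchmark job; num_examples caps the run for smoke testing.
    job = await eval_impl.run_eval(
        task_id="meta-reference-mmlu",
        task_config=BenchmarkEvalTaskConfig(
            eval_candidate=ModelCandidate(
                model="Llama3.2-3B-Instruct", sampling_params=SamplingParams()
            ),
            num_examples=3,
        ),
    )
    return await eval_impl.job_result("meta-reference-mmlu", job.job_id)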
@@ -49,6 +49,7 @@ from llama_stack.apis.models import *  # noqa: F403
 from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.apis.shields import *  # noqa: F403
 from llama_stack.apis.inspect import *  # noqa: F403
+from llama_stack.apis.eval_tasks import *  # noqa: F403


 class LlamaStack(
@@ -63,6 +64,7 @@ class LlamaStack(
     PostTraining,
     Memory,
     Eval,
+    EvalTasks,
     Scoring,
     ScoringFunctions,
     DatasetIO,
(Two file diffs suppressed because they are too large.)
@@ -40,6 +40,10 @@ EvalCandidate = Annotated[
 class BenchmarkEvalTaskConfig(BaseModel):
     type: Literal["benchmark"] = "benchmark"
     eval_candidate: EvalCandidate
+    num_examples: Optional[int] = Field(
+        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
+        default=None,
+    )


 @json_schema_type
@@ -50,6 +54,10 @@ class AppEvalTaskConfig(BaseModel):
         description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
+    num_examples: Optional[int] = Field(
+        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
+        default=None,
+    )
     # we could optinally add any specific dataset config here
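Both task config types now take this optional num_examples cap. A minimal sketch of an app-level config that evaluates only the first five rows; the model name is illustrative (the new tests below use Llama3.2-3B-Instruct):

from llama_models.llama3.api import SamplingParams
from llama_stack.apis.eval.eval import AppEvalTaskConfig, ModelCandidate

task_config = AppEvalTaskConfig(
    eval_candidate=ModelCandidate(
        model="Llama3.2-3B-Instruct",
        sampling_params=SamplingParams(),
    ),
    # Evaluate only the first 5 dataset rows; leave unset to evaluate them all.
    num_examples=5,
)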
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import HuggingfaceDatasetIOConfig
+
+
+async def get_adapter_impl(
+    config: HuggingfaceDatasetIOConfig,
+    _deps,
+):
+    from .huggingface import HuggingfaceDatasetIOImpl
+
+    impl = HuggingfaceDatasetIOImpl(config)
+    await impl.initialize()
+    return impl
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.datasetio import *  # noqa: F401, F403
+
+
+class HuggingfaceDatasetIOConfig(BaseModel): ...
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import List, Optional
+
+from llama_stack.apis.datasetio import *  # noqa: F403
+
+
+import datasets as hf_datasets
+from llama_stack.providers.datatypes import DatasetsProtocolPrivate
+from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+
+from .config import HuggingfaceDatasetIOConfig
+
+
+def load_hf_dataset(dataset_def: DatasetDef):
+    if dataset_def.metadata.get("path", None):
+        return hf_datasets.load_dataset(**dataset_def.metadata)
+
+    df = get_dataframe_from_url(dataset_def.url)
+
+    if df is None:
+        raise ValueError(f"Failed to load dataset from {dataset_def.url}")
+
+    dataset = hf_datasets.Dataset.from_pandas(df)
+    return dataset
+
+
+class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
+    def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
+        self.config = config
+        # local registry for keeping track of datasets within the provider
+        self.dataset_infos = {}
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None: ...
+
+    async def register_dataset(
+        self,
+        dataset_def: DatasetDef,
+    ) -> None:
+        self.dataset_infos[dataset_def.identifier] = dataset_def
+
+    async def list_datasets(self) -> List[DatasetDef]:
+        return list(self.dataset_infos.values())
+
+    async def get_rows_paginated(
+        self,
+        dataset_id: str,
+        rows_in_page: int,
+        page_token: Optional[str] = None,
+        filter_condition: Optional[str] = None,
+    ) -> PaginatedRowsResult:
+        dataset_def = self.dataset_infos[dataset_id]
+        loaded_dataset = load_hf_dataset(dataset_def)
+
+        if page_token and not page_token.isnumeric():
+            raise ValueError("Invalid page_token")
+
+        if page_token is None or len(page_token) == 0:
+            next_page_token = 0
+        else:
+            next_page_token = int(page_token)
+
+        start = next_page_token
+        if rows_in_page == -1:
+            end = len(loaded_dataset)
+        else:
+            end = min(start + rows_in_page, len(loaded_dataset))
+
+        rows = [loaded_dataset[i] for i in range(start, end)]
+
+        return PaginatedRowsResult(
+            rows=rows,
+            total_count=len(rows),
+            next_page_token=str(end),
+        )
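A rough usage sketch of the adapter on its own, outside the stack router. The import paths for the impl and config classes come from the provider registry entry added in this commit; the DatasetDef constructor fields are assumptions of this sketch, mirroring the DatasetDefWithProvider registered in the tests below:

import asyncio

from llama_models.llama3.api import URL
from llama_stack.apis.common.type_system import StringType
from llama_stack.apis.datasetio import DatasetDef
from llama_stack.providers.adapters.datasetio.huggingface.config import (
    HuggingfaceDatasetIOConfig,
)
from llama_stack.providers.adapters.datasetio.huggingface.huggingface import (
    HuggingfaceDatasetIOImpl,
)


async def demo() -> None:
    impl = HuggingfaceDatasetIOImpl(HuggingfaceDatasetIOConfig())
    await impl.initialize()

    # Because metadata["path"] is set, load_hf_dataset() goes through
    # hf_datasets.load_dataset(**metadata) instead of fetching a csv/xlsx URL.
    await impl.register_dataset(
        DatasetDef(
            identifier="mmlu",
            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
            dataset_schema={
                "input_query": StringType(),
                "expected_answer": StringType(),
            },
            metadata={
                "path": "llamastack/evals",
                "name": "evals__mmlu__details",
                "split": "train",
            },
        )
    )

    # Page through three rows at a time; next_page_token is the stringified
    # index where the next page starts.
    first = await impl.get_rows_paginated(dataset_id="mmlu", rows_in_page=3)
    second = await impl.get_rows_paginated(
        dataset_id="mmlu", rows_in_page=3, page_token=first.next_page_token
    )
    print(len(first.rows), len(second.rows))


if __name__ == "__main__":
    asyncio.run(demo())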
@@ -3,20 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import io
 from typing import List, Optional

 import pandas
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from llama_stack.apis.datasetio import *  # noqa: F403
-import base64
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from urllib.parse import unquote

 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import parse_data_url
+from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url

 from .config import MetaReferenceDatasetIOConfig

@@ -73,31 +70,9 @@ class PandasDataframeDataset(BaseDataset):
         if self.df is not None:
             return

-        # TODO: more robust support w/ data url
-        if self.dataset_def.url.uri.endswith(".csv"):
-            df = pandas.read_csv(self.dataset_def.url.uri)
-        elif self.dataset_def.url.uri.endswith(".xlsx"):
-            df = pandas.read_excel(self.dataset_def.url.uri)
-        elif self.dataset_def.url.uri.startswith("data:"):
-            parts = parse_data_url(self.dataset_def.url.uri)
-            data = parts["data"]
-            if parts["is_base64"]:
-                data = base64.b64decode(data)
-            else:
-                data = unquote(data)
-                encoding = parts["encoding"] or "utf-8"
-                data = data.encode(encoding)
-
-            mime_type = parts["mimetype"]
-            mime_category = mime_type.split("/")[0]
-            data_bytes = io.BytesIO(data)
-
-            if mime_category == "text":
-                df = pandas.read_csv(data_bytes)
-            else:
-                df = pandas.read_excel(data_bytes)
-        else:
-            raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
+        df = get_dataframe_from_url(self.dataset_def.url)
+        if df is None:
+            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")

         self.df = self._validate_dataset_schema(df)
@@ -9,6 +9,8 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from .....apis.common.job_types import Job
 from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.common.type_system import *  # noqa: F403
+from tqdm import tqdm
+
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTaskDef
@@ -47,7 +49,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):

         self.eval_tasks = {}

-    async def initialize(self) -> None: ...
+    async def initialize(self) -> None:
+        pass

     async def shutdown(self) -> None: ...

@@ -93,7 +96,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
-            rows_in_page=-1,
+            rows_in_page=(
+                -1 if task_config.num_examples is None else task_config.num_examples
+            ),
         )
         res = await self.evaluate_rows(
             task_id=task_id,
@@ -125,7 +130,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         ), "SamplingParams.max_tokens must be provided"

         generations = []
-        for x in input_rows:
+        for x in tqdm(input_rows):
             if ColumnName.completion_input.value in x:
                 input_content = eval(str(x[ColumnName.completion_input.value]))
                 response = await self.inference_api.completion(
@@ -13,21 +13,14 @@ from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.inference.inference import Inference
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
-    EqualityScoringFn,
-)
-
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
-    LlmAsJudgeScoringFn,
-)
-
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
-    SubsetOfScoringFn,
-)
-
 from .config import MetaReferenceScoringConfig
+from .scoring_fn.equality_scoring_fn import EqualityScoringFn
+from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
+from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
+from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn

-FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]
+FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]

 LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
@@ -11,6 +11,5 @@ from llama_stack.apis.scoring_functions import ScoringFnDef
 equality = ScoringFnDef(
     identifier="meta-reference::equality",
     description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
-    parameters=[],
     return_type=NumberType(),
 )
@@ -26,7 +26,6 @@ Total rating:
 llm_as_judge_8b_correctness = ScoringFnDef(
     identifier="meta-reference::llm_as_judge_8b_correctness",
     description="Llm As Judge Scoring Function",
-    parameters=[],
     return_type=NumberType(),
     params=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import NumberType
+
+MULTILINGUAL_ANSWER_REGEXES = [
+    r"Answer\s*:",
+    r"Answer\s*:",  # Korean invisible character
+    r"উত্তর\s*:",
+    r"उत्तर\s*:",
+    r"উত্তরঃ",
+    r"উত্তর\s*:",
+    r"Antwort\s*:",
+    r"답변\s*:",
+    r"정답\s*:",
+    r"답\s*:",
+    r"答案\s*:",
+    r"答案\s*:",
+    r"答\s*:",
+    r"答\s*:",
+    r"答复\s*:",
+    r"答曰\s*:",
+    r"الإجابة:",
+    r"الجواب:",
+    r"إجابة:",
+    r"الإجابة النهائية:",
+    r"الإجابة الصحيحة:",
+    r"الإجابة الصحيحة هي:",
+    r"الإجابة هي:",
+    r"Respuesta\s*:",
+    r"Risposta\s*:",
+    r"答え\s*:",
+    r"答え\s*:",
+    r"回答\s*:",
+    r"回答\s*:",
+    r"解答\s*:",
+    r"Jawaban\s*:",
+    r"Réponse\s*:",
+    r"Resposta\s*:",
+    r"Jibu\s*:",
+    r"Idahun\s*:",
+    r"Ìdáhùn\s*:",
+    r"Idáhùn\s*:",
+    r"Àmọ̀nà\s*:",
+    r"Àdáhùn\s*:",
+    r"Ànúgọ\s*:",
+    r"Àṣàyàn\s*:",
+]
+
+MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
+    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
+)
+
+regex_parser_multiple_choice_answer = ScoringFnDef(
+    identifier="meta-reference::regex_parser_multiple_choice_answer",
+    description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
+    return_type=NumberType(),
+    params=RegexParserScoringFnParams(
+        parsing_regexes=[
+            MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
+            for x in MULTILINGUAL_ANSWER_REGEXES
+        ],
+    ),
+)
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import re
+
+from .base_scoring_fn import BaseScoringFn
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from .common import aggregate_accuracy
+
+from .fn_defs.regex_parser_multiple_choice_answer import (
+    regex_parser_multiple_choice_answer,
+)
+
+
+class RegexParserScoringFn(BaseScoringFn):
+    """
+    A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {
+            regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
+        }
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> ScoringResultRow:
+        assert (
+            scoring_fn_identifier is not None
+        ), "Scoring function identifier not found."
+        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+        if scoring_params is not None:
+            fn_def.params = scoring_params
+
+        assert (
+            fn_def.params is not None
+            and fn_def.params.type == ScoringConfigType.regex_parser.value
+        ), f"RegexParserScoringFnParams not found for {fn_def}."
+
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+
+        # parse answer according to regex
+        parsed_answer = None
+        for regex in fn_def.params.parsing_regexes:
+            match = re.search(regex, generated_answer)
+            if match:
+                parsed_answer = match.group(1)
+                break
+
+        score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        return aggregate_accuracy(scoring_results)
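For intuition, the extract-then-compare logic above boils down to the following self-contained snippet (plain re, no stack dependencies; the two prefixes and the trimmed capture group are a small subset of the multilingual patterns defined in the fn_def above):

import re

# Same capture idea as MULTILINGUAL_ANSWER_PATTERN_TEMPLATE, trimmed to Latin letters.
PATTERN_TEMPLATE = r"(?i){}\s*([A-D])"
PREFIXES = [r"Answer\s*:", r"Respuesta\s*:"]

generated_answer = "Let's reason step by step. Answer: B"
expected_answer = "B"

# Try each prefix until one matches, then capture the answer letter.
parsed_answer = None
for prefix in PREFIXES:
    match = re.search(PATTERN_TEMPLATE.format(prefix), generated_answer)
    if match:
        parsed_answer = match.group(1)
        break

score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
print(parsed_answer, score)  # -> B 1.0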
@@ -19,4 +19,15 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.inline.meta_reference.datasetio.MetaReferenceDatasetIOConfig",
             api_dependencies=[],
         ),
+        remote_provider_spec(
+            api=Api.datasetio,
+            adapter=AdapterSpec(
+                adapter_type="huggingface",
+                pip_packages=[
+                    "datasets",
+                ],
+                module="llama_stack.providers.adapters.datasetio.huggingface",
+                config_class="llama_stack.providers.adapters.datasetio.huggingface.HuggingfaceDatasetIOConfig",
+            ),
+        ),
     ]
@@ -31,7 +31,20 @@ def datasetio_meta_reference() -> ProviderFixture:
     )


-DATASETIO_FIXTURES = ["meta_reference", "remote"]
+@pytest.fixture(scope="session")
+def datasetio_huggingface() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="huggingface",
+                provider_type="remote::huggingface",
+                config={},
+            )
+        ],
+    )
+
+
+DATASETIO_FIXTURES = ["meta_reference", "remote", "huggingface"]


 @pytest_asyncio.fixture(scope="session")
@@ -34,6 +34,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="meta_reference_eval_together_inference",
         marks=pytest.mark.meta_reference_eval_together_inference,
     ),
+    pytest.param(
+        {
+            "eval": "meta_reference",
+            "scoring": "meta_reference",
+            "datasetio": "huggingface",
+            "inference": "together",
+        },
+        id="meta_reference_eval_together_inference_huggingface_datasetio",
+        marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
+    ),
 ]


@@ -41,6 +51,7 @@ def pytest_configure(config):
     for fixture_name in [
         "meta_reference_eval_fireworks_inference",
         "meta_reference_eval_together_inference",
+        "meta_reference_eval_together_inference_huggingface_datasetio",
     ]:
         config.addinivalue_line(
             "markers",
@@ -7,10 +7,15 @@

 import pytest

-from llama_models.llama3.api import SamplingParams
+from llama_models.llama3.api import SamplingParams, URL
+
+from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
+
+from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider
+
 from llama_stack.apis.eval.eval import (
     AppEvalTaskConfig,
+    BenchmarkEvalTaskConfig,
     EvalTaskDefWithProvider,
     ModelCandidate,
 )
@@ -21,7 +26,7 @@ from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 # How to run this test:
 #
 # pytest llama_stack/providers/tests/eval/test_eval.py
-#   -m "meta_reference"
+#   -m "meta_reference_eval_together_inference_huggingface_datasetio"
 #   -v -s --tb=short --disable-warnings


@@ -33,21 +38,26 @@ class Testeval:
         eval_tasks_impl = eval_stack[Api.eval_tasks]
         response = await eval_tasks_impl.list_eval_tasks()
         assert isinstance(response, list)
-        assert len(response) == 0

     @pytest.mark.asyncio
     async def test_eval_evaluate_rows(self, eval_stack):
-        eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl = (
+        eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
             eval_stack[Api.datasetio],
             eval_stack[Api.datasets],
+            eval_stack[Api.models],
         )
+        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+            await models_impl.register_model(
+                model_id=model_id,
+                provider_id="",
+            )
         await register_dataset(
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
         response = await datasets_impl.list_datasets()
-        assert len(response) == 1
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset_for_eval",
             rows_in_page=3,
@@ -66,7 +76,6 @@ class Testeval:
             provider_id="meta-reference",
         )
         await eval_tasks_impl.register_eval_task(task_def)
-
         response = await eval_impl.evaluate_rows(
             task_id=task_id,
             input_rows=rows.rows,
@@ -84,11 +93,17 @@ class Testeval:

     @pytest.mark.asyncio
     async def test_eval_run_eval(self, eval_stack):
-        eval_impl, eval_tasks_impl, datasets_impl = (
+        eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
             eval_stack[Api.datasets],
+            eval_stack[Api.models],
         )
+        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+            await models_impl.register_model(
+                model_id=model_id,
+                provider_id="",
+            )
         await register_dataset(
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
@@ -124,3 +139,72 @@ class Testeval:
         assert len(eval_response.generations) == 5
         assert "meta-reference::subset_of" in eval_response.scores
         assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores
+
+    @pytest.mark.asyncio
+    async def test_eval_run_benchmark_eval(self, eval_stack):
+        eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
+            eval_stack[Api.eval],
+            eval_stack[Api.eval_tasks],
+            eval_stack[Api.datasets],
+            eval_stack[Api.models],
+        )
+        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+            await models_impl.register_model(
+                model_id=model_id,
+                provider_id="",
+            )
+        response = await datasets_impl.list_datasets()
+        assert len(response) > 0
+        if response[0].provider_id != "huggingface":
+            pytest.skip(
+                "Only huggingface provider supports pre-registered remote datasets"
+            )
+        # register dataset
+        mmlu = DatasetDefWithProvider(
+            identifier="mmlu",
+            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
+            dataset_schema={
+                "input_query": StringType(),
+                "expected_answer": StringType(),
+                "chat_completion_input": ChatCompletionInputType(),
+            },
+            metadata={
+                "path": "llamastack/evals",
+                "name": "evals__mmlu__details",
+                "split": "train",
+            },
+            provider_id="",
+        )
+
+        await datasets_impl.register_dataset(mmlu)
+
+        # register eval task
+        meta_reference_mmlu = EvalTaskDefWithProvider(
+            identifier="meta-reference-mmlu",
+            dataset_id="mmlu",
+            scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
+            provider_id="",
+        )
+
+        await eval_tasks_impl.register_eval_task(meta_reference_mmlu)
+
+        # list benchmarks
+        response = await eval_tasks_impl.list_eval_tasks()
+        assert len(response) > 0
+
+        benchmark_id = "meta-reference-mmlu"
+        response = await eval_impl.run_eval(
+            task_id=benchmark_id,
+            task_config=BenchmarkEvalTaskConfig(
+                eval_candidate=ModelCandidate(
+                    model="Llama3.2-3B-Instruct",
+                    sampling_params=SamplingParams(),
+                ),
+                num_examples=3,
+            ),
+        )
+        job_status = await eval_impl.job_status(benchmark_id, response.job_id)
+        assert job_status and job_status.value == "completed"
+        eval_response = await eval_impl.job_result(benchmark_id, response.job_id)
+        assert eval_response is not None
+        assert len(eval_response.generations) == 3
|
5
llama_stack/providers/utils/datasetio/__init__.py
Normal file
5
llama_stack/providers/utils/datasetio/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
llama_stack/providers/utils/datasetio/url_utils.py (new file, 45 lines)

@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import io
+from urllib.parse import unquote
+
+import pandas
+
+from llama_models.llama3.api.datatypes import URL
+
+from llama_stack.providers.utils.memory.vector_store import parse_data_url
+
+
+def get_dataframe_from_url(url: URL):
+    df = None
+    if url.uri.endswith(".csv"):
+        df = pandas.read_csv(url.uri)
+    elif url.uri.endswith(".xlsx"):
+        df = pandas.read_excel(url.uri)
+    elif url.uri.startswith("data:"):
+        parts = parse_data_url(url.uri)
+        data = parts["data"]
+        if parts["is_base64"]:
+            data = base64.b64decode(data)
+        else:
+            data = unquote(data)
+            encoding = parts["encoding"] or "utf-8"
+            data = data.encode(encoding)
+
+        mime_type = parts["mimetype"]
+        mime_category = mime_type.split("/")[0]
+        data_bytes = io.BytesIO(data)
+
+        if mime_category == "text":
+            df = pandas.read_csv(data_bytes)
+        else:
+            df = pandas.read_excel(data_bytes)
+    else:
+        raise ValueError(f"Unsupported file type: {url}")
+
+    return df
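A quick sketch of the data: branch of this helper, assuming parse_data_url accepts the standard data:text/csv;base64,<payload> URI form (this is how inline CSV datasets reach the providers):

import base64

from llama_models.llama3.api.datatypes import URL
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url

# Build a tiny base64-encoded CSV data URL and decode it back into a dataframe.
payload = base64.b64encode(b"input_query,expected_answer\nWhat is 2+2?,4\n").decode()
df = get_dataframe_from_url(URL(uri=f"data:text/csv;base64,{payload}"))
print(df.columns.tolist())  # expected: ['input_query', 'expected_answer']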