[Evals API][11/n] huggingface dataset provider + mmlu scoring fn (#392)

* wip

* scoring fn api

* eval api

* eval task

* evaluate api update

* pre commit

* unwrap context -> config

* config field doc

* typo

* naming fix

* separate benchmark / app eval

* api name

* rename

* wip tests

* wip

* datasetio test

* delete unused

* fixture

* scoring resolve

* fix scoring register

* scoring test pass

* score batch

* scoring fix

* fix eval

* test eval works

* huggingface provider

* datasetdef files

* mmlu scoring fn

* test wip

* remove type ignore

* api refactor

* add default task_eval_id for routing

* add eval_id for jobs

* remove type ignore

* huggingface provider

* wip huggingface register

* only keep 1 run_eval

* fix optional

* register task required

* register task required

* delete old tests

* fix

* mmlu loose

* refactor

* msg

* fix tests

* move benchmark task def to file

* msg

* gen openapi

* openapi gen

* move dataset to hf llamastack repo

* remove todo

* refactor

* add register model to unit test

* rename

* register to client

* delete preregistered dataset/eval task

* comments

* huggingface -> remote adapter

* openapi gen
Authored by Xi Yan on 2024-11-11 14:49:50 -05:00; committed by GitHub.
parent b78ee3a0a5
commit 2b7d70ba86
20 changed files with 1607 additions and 718 deletions


@@ -40,6 +40,10 @@ EvalCandidate = Annotated[
class BenchmarkEvalTaskConfig(BaseModel):
    type: Literal["benchmark"] = "benchmark"
    eval_candidate: EvalCandidate
    num_examples: Optional[int] = Field(
        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
        default=None,
    )


@json_schema_type
@@ -50,6 +54,10 @@ class AppEvalTaskConfig(BaseModel):
        description="Map between scoring function id and parameters for each scoring function you want to run",
        default_factory=dict,
    )
    num_examples: Optional[int] = Field(
        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
        default=None,
    )
    # we could optionally add any specific dataset config here
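A minimal illustrative sketch (not part of this commit) of how the new num_examples field might be used to cap an app eval at a handful of rows. It assumes AppEvalTaskConfig accepts the same eval_candidate field shown for the benchmark config, and that ModelCandidate and SamplingParams are importable as in the test changes further down; the model name is a placeholder.

from llama_models.llama3.api import SamplingParams
from llama_stack.apis.eval.eval import AppEvalTaskConfig, ModelCandidate

# Evaluate only the first 5 rows; leaving num_examples as None (the default)
# evaluates every example in the dataset.
task_config = AppEvalTaskConfig(
    eval_candidate=ModelCandidate(
        model="Llama3.2-3B-Instruct",
        sampling_params=SamplingParams(),
    ),
    num_examples=5,
)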


@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import HuggingfaceDatasetIOConfig


async def get_adapter_impl(
    config: HuggingfaceDatasetIOConfig,
    _deps,
):
    from .huggingface import HuggingfaceDatasetIOImpl

    impl = HuggingfaceDatasetIOImpl(config)
    await impl.initialize()
    return impl


@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import * # noqa: F401, F403


class HuggingfaceDatasetIOConfig(BaseModel): ...


@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional

from llama_stack.apis.datasetio import * # noqa: F403

import datasets as hf_datasets

from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url

from .config import HuggingfaceDatasetIOConfig


def load_hf_dataset(dataset_def: DatasetDef):
    if dataset_def.metadata.get("path", None):
        return hf_datasets.load_dataset(**dataset_def.metadata)

    df = get_dataframe_from_url(dataset_def.url)
    if df is None:
        raise ValueError(f"Failed to load dataset from {dataset_def.url}")

    dataset = hf_datasets.Dataset.from_pandas(df)
    return dataset


class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
    def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
        self.config = config
        # local registry for keeping track of datasets within the provider
        self.dataset_infos = {}

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None: ...

    async def register_dataset(
        self,
        dataset_def: DatasetDef,
    ) -> None:
        self.dataset_infos[dataset_def.identifier] = dataset_def

    async def list_datasets(self) -> List[DatasetDef]:
        return list(self.dataset_infos.values())

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        dataset_def = self.dataset_infos[dataset_id]
        loaded_dataset = load_hf_dataset(dataset_def)

        if page_token and not page_token.isnumeric():
            raise ValueError("Invalid page_token")

        if page_token is None or len(page_token) == 0:
            next_page_token = 0
        else:
            next_page_token = int(page_token)

        start = next_page_token
        if rows_in_page == -1:
            end = len(loaded_dataset)
        else:
            end = min(start + rows_in_page, len(loaded_dataset))

        rows = [loaded_dataset[i] for i in range(start, end)]

        return PaginatedRowsResult(
            rows=rows,
            total_count=len(rows),
            next_page_token=str(end),
        )
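An illustrative sketch (not part of this commit): the provider above uses a stringified row index as its page token, so a caller can walk a dataset page by page. This assumes `datasetio` is an initialized DatasetIO implementation (such as the HuggingfaceDatasetIOImpl above) with the dataset already registered.

async def fetch_all_rows(datasetio, dataset_id: str, page_size: int = 100):
    rows, token = [], None
    while True:
        page = await datasetio.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=page_size,
            page_token=token,
        )
        if not page.rows:
            # an empty page means the previous page already ended at the last row
            return rows
        rows.extend(page.rows)
        token = page.next_page_token  # stringified index of the next start row

Passing rows_in_page=-1 instead returns the whole dataset in a single page.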


@@ -3,20 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
from typing import List, Optional

import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403

from llama_stack.apis.datasetio import * # noqa: F403
import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass
from urllib.parse import unquote

from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url

from .config import MetaReferenceDatasetIOConfig

@@ -73,31 +70,9 @@ class PandasDataframeDataset(BaseDataset):
        if self.df is not None:
            return

        # TODO: more robust support w/ data url
        if self.dataset_def.url.uri.endswith(".csv"):
            df = pandas.read_csv(self.dataset_def.url.uri)
        elif self.dataset_def.url.uri.endswith(".xlsx"):
            df = pandas.read_excel(self.dataset_def.url.uri)
        elif self.dataset_def.url.uri.startswith("data:"):
            parts = parse_data_url(self.dataset_def.url.uri)
            data = parts["data"]

            if parts["is_base64"]:
                data = base64.b64decode(data)
            else:
                data = unquote(data)
                encoding = parts["encoding"] or "utf-8"
                data = data.encode(encoding)

            mime_type = parts["mimetype"]
            mime_category = mime_type.split("/")[0]
            data_bytes = io.BytesIO(data)

            if mime_category == "text":
                df = pandas.read_csv(data_bytes)
            else:
                df = pandas.read_excel(data_bytes)
        else:
            raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
        df = get_dataframe_from_url(self.dataset_def.url)

        if df is None:
            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")

        self.df = self._validate_dataset_schema(df)


@@ -9,6 +9,8 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import * # noqa: F403
from tqdm import tqdm

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTaskDef

@@ -47,7 +49,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
        self.eval_tasks = {}

    async def initialize(self) -> None: ...
    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None: ...

@@ -93,7 +96,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
        await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
        all_rows = await self.datasetio_api.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=-1,
            rows_in_page=(
                -1 if task_config.num_examples is None else task_config.num_examples
            ),
        )
        res = await self.evaluate_rows(
            task_id=task_id,

@@ -125,7 +130,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
        ), "SamplingParams.max_tokens must be provided"

        generations = []
        for x in input_rows:
        for x in tqdm(input_rows):
            if ColumnName.completion_input.value in x:
                input_content = eval(str(x[ColumnName.completion_input.value]))
                response = await self.inference_api.completion(
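A minimal sketch (not part of this commit) of the mapping introduced in the run_eval hunk above: num_examples is translated into the rows_in_page argument, where -1 is the "fetch everything" sentinel understood by the datasetio providers.

from typing import Optional

def rows_to_fetch(num_examples: Optional[int]) -> int:
    # -1 asks the datasetio provider for every row; any positive value caps it
    return -1 if num_examples is None else num_examples

assert rows_to_fetch(None) == -1
assert rows_to_fetch(3) == 3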


@@ -13,21 +13,14 @@ from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
    EqualityScoringFn,
)
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
    LlmAsJudgeScoringFn,
)
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
    SubsetOfScoringFn,
)

from .config import MetaReferenceScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn

FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]

LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
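A hedged sketch (not part of this commit) of how FIXED_FNS could be turned into a lookup table keyed by scoring function identifier. The actual wiring inside MetaReferenceScoringImpl is not shown in this hunk; the sketch relies only on the supported_fn_defs_registry attribute visible in RegexParserScoringFn below and assumes the fixed scoring fns construct without arguments.

scoring_fn_by_id = {}
for fn_cls in FIXED_FNS:
    fn = fn_cls()  # assumption: no constructor arguments needed
    for identifier in fn.supported_fn_defs_registry:
        scoring_fn_by_id[identifier] = fn

# e.g. scoring_fn_by_id["meta-reference::regex_parser_multiple_choice_answer"]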


@@ -11,6 +11,5 @@ from llama_stack.apis.scoring_functions import ScoringFnDef

equality = ScoringFnDef(
    identifier="meta-reference::equality",
    description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
    parameters=[],
    return_type=NumberType(),
)


@@ -26,7 +26,6 @@ Total rating:
llm_as_judge_8b_correctness = ScoringFnDef(
    identifier="meta-reference::llm_as_judge_8b_correctness",
    description="Llm As Judge Scoring Function",
    parameters=[],
    return_type=NumberType(),
    params=LLMAsJudgeScoringFnParams(
        prompt_template=JUDGE_PROMPT,


@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType

MULTILINGUAL_ANSWER_REGEXES = [
    r"Answer\s*:",
    r"Answer\s*:",  # Korean invisible character
    r"উত্তর\s*:",
    r"उत्तर\s*:",
    r"উত্তরঃ",
    r"উত্তর\s*:",
    r"Antwort\s*:",
    r"답변\s*:",
    r"정답\s*:",
    r"답\s*:",
    r"答案\s*：",
    r"答案\s*:",
    r"答\s*：",
    r"答\s*:",
    r"答复\s*：",
    r"答曰\s*：",
    r"الإجابة:",
    r"الجواب:",
    r"إجابة:",
    r"الإجابة النهائية:",
    r"الإجابة الصحيحة:",
    r"الإجابة الصحيحة هي:",
    r"الإجابة هي:",
    r"Respuesta\s*:",
    r"Risposta\s*:",
    r"答え\s*:",
    r"答え\s*：",
    r"回答\s*:",
    r"回答\s*：",
    r"解答\s*:",
    r"Jawaban\s*:",
    r"Réponse\s*:",
    r"Resposta\s*:",
    r"Jibu\s*:",
    r"Idahun\s*:",
    r"Ìdáhùn\s*:",
    r"Idáhùn\s*:",
    r"Àmọ̀nà\s*:",
    r"Àdáhùn\s*:",
    r"Ànúgọ\s*:",
    r"Àṣàyàn\s*:",
]

MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[এ]|[বি]|[সি]|[ডি])"
)

regex_parser_multiple_choice_answer = ScoringFnDef(
    identifier="meta-reference::regex_parser_multiple_choice_answer",
    description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
    return_type=NumberType(),
    params=RegexParserScoringFnParams(
        parsing_regexes=[
            MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
            for x in MULTILINGUAL_ANSWER_REGEXES
        ],
    ),
)
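An illustrative sketch (not part of this commit) of what one formatted pattern looks like and how it pulls the choice letter out of a generation. Only the English "Answer:" prefix is exercised here, and the sample generation is made up.

import re

pattern = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(r"Answer\s*:")
generation = "Let's reason about the options.\nAnswer: B"
match = re.search(pattern, generation)
assert match is not None and match.group(1) == "B"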


@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re

from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from .common import aggregate_accuracy

from .fn_defs.regex_parser_multiple_choice_answer import (
    regex_parser_multiple_choice_answer,
)


class RegexParserScoringFn(BaseScoringFn):
    """
    A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.supported_fn_defs_registry = {
            regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
        }

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        assert (
            scoring_fn_identifier is not None
        ), "Scoring function identifier not found."
        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
        if scoring_params is not None:
            fn_def.params = scoring_params

        assert (
            fn_def.params is not None
            and fn_def.params.type == ScoringConfigType.regex_parser.value
        ), f"RegexParserScoringFnParams not found for {fn_def}."

        expected_answer = input_row["expected_answer"]
        generated_answer = input_row["generated_answer"]

        # parse answer according to regex
        parsed_answer = None
        for regex in fn_def.params.parsing_regexes:
            match = re.search(regex, generated_answer)
            if match:
                parsed_answer = match.group(1)
                break

        score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
        return {
            "score": score,
        }

    async def aggregate(
        self, scoring_results: List[ScoringResultRow]
    ) -> Dict[str, Any]:
        return aggregate_accuracy(scoring_results)
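A minimal sketch (not part of this commit) of scoring a single MMLU-style row with the class above. It assumes RegexParserScoringFn can be constructed with no arguments (its __init__ only forwards to the base class); the row contents are invented.

import asyncio

fn = RegexParserScoringFn()  # assumption: zero-argument construction is supported
row = {
    "expected_answer": "B",
    "generated_answer": "The other options are distractors.\nAnswer: B",
}
result = asyncio.run(
    fn.score_row(
        row,
        scoring_fn_identifier="meta-reference::regex_parser_multiple_choice_answer",
    )
)
assert result == {"score": 1.0}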


@@ -19,4 +19,15 @@ def available_providers() -> List[ProviderSpec]:
            config_class="llama_stack.providers.inline.meta_reference.datasetio.MetaReferenceDatasetIOConfig",
            api_dependencies=[],
        ),
        remote_provider_spec(
            api=Api.datasetio,
            adapter=AdapterSpec(
                adapter_type="huggingface",
                pip_packages=[
                    "datasets",
                ],
                module="llama_stack.providers.adapters.datasetio.huggingface",
                config_class="llama_stack.providers.adapters.datasetio.huggingface.HuggingfaceDatasetIOConfig",
            ),
        ),
    ]


@@ -31,7 +31,20 @@ def datasetio_meta_reference() -> ProviderFixture:
    )


DATASETIO_FIXTURES = ["meta_reference", "remote"]
@pytest.fixture(scope="session")
def datasetio_huggingface() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="huggingface",
                provider_type="remote::huggingface",
                config={},
            )
        ],
    )


DATASETIO_FIXTURES = ["meta_reference", "remote", "huggingface"]


@pytest_asyncio.fixture(scope="session")


@@ -34,6 +34,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
        id="meta_reference_eval_together_inference",
        marks=pytest.mark.meta_reference_eval_together_inference,
    ),
    pytest.param(
        {
            "eval": "meta_reference",
            "scoring": "meta_reference",
            "datasetio": "huggingface",
            "inference": "together",
        },
        id="meta_reference_eval_together_inference_huggingface_datasetio",
        marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
    ),
]

@@ -41,6 +51,7 @@ def pytest_configure(config):
    for fixture_name in [
        "meta_reference_eval_fireworks_inference",
        "meta_reference_eval_together_inference",
        "meta_reference_eval_together_inference_huggingface_datasetio",
    ]:
        config.addinivalue_line(
            "markers",


@@ -7,10 +7,15 @@

import pytest
from llama_models.llama3.api import SamplingParams
from llama_models.llama3.api import SamplingParams, URL
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider

from llama_stack.apis.eval.eval import (
    AppEvalTaskConfig,
    BenchmarkEvalTaskConfig,
    EvalTaskDefWithProvider,
    ModelCandidate,
)

@@ -21,7 +26,7 @@ from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
# How to run this test:
#
# pytest llama_stack/providers/tests/eval/test_eval.py
#   -m "meta_reference"
#   -m "meta_reference_eval_together_inference_huggingface_datasetio"
#   -v -s --tb=short --disable-warnings

@@ -33,21 +38,26 @@ class Testeval:
        eval_tasks_impl = eval_stack[Api.eval_tasks]
        response = await eval_tasks_impl.list_eval_tasks()
        assert isinstance(response, list)
        assert len(response) == 0

    @pytest.mark.asyncio
    async def test_eval_evaluate_rows(self, eval_stack):
        eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl = (
        eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl, models_impl = (
            eval_stack[Api.eval],
            eval_stack[Api.eval_tasks],
            eval_stack[Api.datasetio],
            eval_stack[Api.datasets],
            eval_stack[Api.models],
        )
        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
            await models_impl.register_model(
                model_id=model_id,
                provider_id="",
            )
        await register_dataset(
            datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
        )
        response = await datasets_impl.list_datasets()
        assert len(response) == 1

        rows = await datasetio_impl.get_rows_paginated(
            dataset_id="test_dataset_for_eval",
            rows_in_page=3,

@@ -66,7 +76,6 @@ class Testeval:
            provider_id="meta-reference",
        )
        await eval_tasks_impl.register_eval_task(task_def)

        response = await eval_impl.evaluate_rows(
            task_id=task_id,
            input_rows=rows.rows,

@@ -84,11 +93,17 @@ class Testeval:

    @pytest.mark.asyncio
    async def test_eval_run_eval(self, eval_stack):
        eval_impl, eval_tasks_impl, datasets_impl = (
        eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
            eval_stack[Api.eval],
            eval_stack[Api.eval_tasks],
            eval_stack[Api.datasets],
            eval_stack[Api.models],
        )
        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
            await models_impl.register_model(
                model_id=model_id,
                provider_id="",
            )
        await register_dataset(
            datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
        )

@@ -124,3 +139,72 @@ class Testeval:
        assert len(eval_response.generations) == 5
        assert "meta-reference::subset_of" in eval_response.scores
        assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores

    @pytest.mark.asyncio
    async def test_eval_run_benchmark_eval(self, eval_stack):
        eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
            eval_stack[Api.eval],
            eval_stack[Api.eval_tasks],
            eval_stack[Api.datasets],
            eval_stack[Api.models],
        )
        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
            await models_impl.register_model(
                model_id=model_id,
                provider_id="",
            )

        response = await datasets_impl.list_datasets()
        assert len(response) > 0
        if response[0].provider_id != "huggingface":
            pytest.skip(
                "Only huggingface provider supports pre-registered remote datasets"
            )

        # register dataset
        mmlu = DatasetDefWithProvider(
            identifier="mmlu",
            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
            dataset_schema={
                "input_query": StringType(),
                "expected_answer": StringType(),
                "chat_completion_input": ChatCompletionInputType(),
            },
            metadata={
                "path": "llamastack/evals",
                "name": "evals__mmlu__details",
                "split": "train",
            },
            provider_id="",
        )
        await datasets_impl.register_dataset(mmlu)

        # register eval task
        meta_reference_mmlu = EvalTaskDefWithProvider(
            identifier="meta-reference-mmlu",
            dataset_id="mmlu",
            scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
            provider_id="",
        )
        await eval_tasks_impl.register_eval_task(meta_reference_mmlu)

        # list benchmarks
        response = await eval_tasks_impl.list_eval_tasks()
        assert len(response) > 0

        benchmark_id = "meta-reference-mmlu"
        response = await eval_impl.run_eval(
            task_id=benchmark_id,
            task_config=BenchmarkEvalTaskConfig(
                eval_candidate=ModelCandidate(
                    model="Llama3.2-3B-Instruct",
                    sampling_params=SamplingParams(),
                ),
                num_examples=3,
            ),
        )
        job_status = await eval_impl.job_status(benchmark_id, response.job_id)
        assert job_status and job_status.value == "completed"
        eval_response = await eval_impl.job_result(benchmark_id, response.job_id)
        assert eval_response is not None
        assert len(eval_response.generations) == 3
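An illustrative sketch (not part of this commit): given the dataset_schema declared in the registration above, each row pulled from the llamastack/evals MMLU split is expected to look roughly like the following. The concrete values are invented; only the column names come from the schema, and the exact serialization of chat_completion_input is an assumption.

example_row = {
    "input_query": "Which planet is known as the Red Planet? (A) Venus (B) Mars (C) Jupiter (D) Saturn",
    "expected_answer": "B",
    # assumption: a serialized list of chat messages consumed by the model candidate
    "chat_completion_input": '[{"role": "user", "content": "..."}]',
}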


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import io
from urllib.parse import unquote

import pandas

from llama_models.llama3.api.datatypes import URL

from llama_stack.providers.utils.memory.vector_store import parse_data_url


def get_dataframe_from_url(url: URL):
    df = None
    if url.uri.endswith(".csv"):
        df = pandas.read_csv(url.uri)
    elif url.uri.endswith(".xlsx"):
        df = pandas.read_excel(url.uri)
    elif url.uri.startswith("data:"):
        parts = parse_data_url(url.uri)
        data = parts["data"]

        if parts["is_base64"]:
            data = base64.b64decode(data)
        else:
            data = unquote(data)
            encoding = parts["encoding"] or "utf-8"
            data = data.encode(encoding)

        mime_type = parts["mimetype"]
        mime_category = mime_type.split("/")[0]
        data_bytes = io.BytesIO(data)

        if mime_category == "text":
            df = pandas.read_csv(data_bytes)
        else:
            df = pandas.read_excel(data_bytes)
    else:
        raise ValueError(f"Unsupported file type: {url}")

    return df
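A minimal sketch (not part of this commit) exercising the data: URL branch of the helper above with a small base64-encoded CSV, assuming a standard "data:text/csv;base64,..." URI is the form accepted by parse_data_url. URL is the same uri wrapper imported in this file.

import base64

from llama_models.llama3.api.datatypes import URL

csv_bytes = b"input_query,expected_answer\nWhat is 2+2?,4\n"
uri = "data:text/csv;base64," + base64.b64encode(csv_bytes).decode("utf-8")

df = get_dataframe_from_url(URL(uri=uri))
assert list(df.columns) == ["input_query", "expected_answer"]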