Folder restructure for evals/datasets/scoring (#419)

* rename evals related stuff

* fix datasetio

* fix scoring test

* localfs -> LocalFS

* refactor scoring

* refactor scoring

* remove 8b_correctness scoring_fn from tests

* tests w/ eval params

* scoring fn braintrust fixture

* import
Xi Yan 2024-11-11 17:35:40 -05:00 committed by GitHub
parent 2b7d70ba86
commit b4416b72fd
37 changed files with 141 additions and 100 deletions


@@ -4,15 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from .config import MetaReferenceDatasetIOConfig
+from .config import LocalFSDatasetIOConfig

 async def get_provider_impl(
-    config: MetaReferenceDatasetIOConfig,
+    config: LocalFSDatasetIOConfig,
     _deps,
 ):
-    from .datasetio import MetaReferenceDatasetIOImpl
+    from .datasetio import LocalFSDatasetIOImpl

-    impl = MetaReferenceDatasetIOImpl(config)
+    impl = LocalFSDatasetIOImpl(config)
     await impl.initialize()
     return impl
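
The entry point's contract is unchanged by the rename: it receives the provider config, constructs the implementation, and awaits its initialization. A minimal sketch of how a caller might exercise it, assuming LocalFSDatasetIOConfig and get_provider_impl are imported from the provider package (the driver function and the empty deps dict are illustrative assumptions, not part of this commit):

import asyncio

async def demo():
    # Hypothetical driver: build the renamed config, then ask the
    # provider entry point for a ready-to-use implementation.
    config = LocalFSDatasetIOConfig()
    impl = await get_provider_impl(config, _deps={})  # _deps is unused by this provider
    return impl

# asyncio.run(demo())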


@@ -6,4 +6,4 @@
 from llama_stack.apis.datasetio import * # noqa: F401, F403

-class MetaReferenceDatasetIOConfig(BaseModel): ...
+class LocalFSDatasetIOConfig(BaseModel): ...


@@ -15,7 +15,7 @@ from dataclasses import dataclass
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
-from .config import MetaReferenceDatasetIOConfig
+from .config import LocalFSDatasetIOConfig

 class BaseDataset(ABC):
@@ -77,8 +77,8 @@ class PandasDataframeDataset(BaseDataset):
         self.df = self._validate_dataset_schema(df)

-class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
-    def __init__(self, config: MetaReferenceDatasetIOConfig) -> None:
+class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
+    def __init__(self, config: LocalFSDatasetIOConfig) -> None:
         self.config = config
         # local registry for keeping track of datasets within the provider
         self.dataset_infos = {}


@@ -9,14 +9,13 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
 from .....apis.common.job_types import Job
 from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.common.type_system import * # noqa: F403
-from tqdm import tqdm
-
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTaskDef
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
+from tqdm import tqdm

 from .config import MetaReferenceEvalConfig


@@ -1,31 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from pathlib import Path
-from typing import Any, Dict, List
-
-from llama_stack.apis.scoring import ScoringResultRow
-
-FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
-
-
-def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
-    num_correct = sum(result["score"] for result in scoring_results)
-    avg_score = num_correct / len(scoring_results)
-    return {
-        "accuracy": avg_score,
-        "num_correct": num_correct,
-        "num_total": len(scoring_results),
-    }
-
-
-def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
-    return {
-        "average": sum(
-            result["score"] for result in scoring_results if result["score"] is not None
-        )
-        / len([_ for _ in scoring_results if _["score"] is not None]),
-    }
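
The two aggregation helpers are deleted here but not lost: the import updates in the hunks below show they now live in llama_stack.providers.utils.scoring.aggregation_utils. A quick worked example of their behavior, using rows shaped like the ScoringResultRow dicts above:

from llama_stack.providers.utils.scoring.aggregation_utils import (
    aggregate_accuracy,
    aggregate_average,
)

rows = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}, {"score": None}]

# aggregate_accuracy sums the scores of all rows it is given.
print(aggregate_accuracy(rows[:3]))
# -> {'accuracy': 0.666..., 'num_correct': 2.0, 'num_total': 3}

# aggregate_average skips rows whose score is None: (1.0 + 0.0 + 1.0) / 3.
print(aggregate_average(rows))
# -> {'average': 0.666...}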


@@ -1,39 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.apis.scoring_functions import * # noqa: F401, F403
-from llama_stack.apis.scoring import * # noqa: F401, F403
-from llama_stack.apis.common.type_system import NumberType
-
-JUDGE_PROMPT = """
-You will be given a question, a expected_answer, and a system_answer.
-Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
-Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
-
-Provide your feedback as follows:
-
-Feedback:::
-Total rating: (your rating, as a int between 0 and 5)
-
-Now here are the question, expected_answer, system_answer.
-
-Question: {input_query}
-Expected Answer: {expected_answer}
-
-System Answer: {generated_answer}
-
-Feedback:::
-Total rating:
-"""
-
-llm_as_judge_8b_correctness = ScoringFnDef(
-    identifier="meta-reference::llm_as_judge_8b_correctness",
-    description="Llm As Judge Scoring Function",
-    return_type=NumberType(),
-    params=LLMAsJudgeScoringFnParams(
-        prompt_template=JUDGE_PROMPT,
-        judge_model="Llama3.1-8B-Instruct",
-        judge_score_regexes=[
-            r"Total rating: (\d+)",
-            r"rating: (\d+)",
-            r"Rating: (\d+)",
-        ],
-    ),
-)
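
The deleted definition hard-coded the judge model, prompt, and score-extraction regexes. For intuition, this is how judge_score_regexes like the ones above pull a numeric rating out of a judge's free-text output (a standalone sketch, not code from this commit):

import re

judge_output = "Feedback:::\nTotal rating: 4"
score = None
for pattern in [r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"]:
    match = re.search(pattern, judge_output)
    if match:
        score = int(match.group(1))  # first matching regex wins
        break
print(score)  # 4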


@@ -16,9 +16,8 @@ from llama_stack.apis.datasets import * # noqa: F403
 from autoevals.llm import Factuality
 from autoevals.ragas import AnswerCorrectness
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
-    aggregate_average,
-)
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average

 from .config import BraintrustScoringConfig
 from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def


@@ -4,20 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
-    BaseScoringFn,
-)
+from .base_scoring_fn import BaseScoringFn
 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
-    aggregate_accuracy,
-)
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.equality import (
-    equality,
-)
+from .fn_defs.equality import equality

 class EqualityScoringFn(BaseScoringFn):


@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import ScoringFnDef
+
+llm_as_judge_base = ScoringFnDef(
+    identifier="meta-reference::llm_as_judge_base",
+    description="Llm As Judge Scoring Function",
+    return_type=NumberType(),
+)
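
Unlike the deleted llm_as_judge_8b_correctness definition, this base definition ships without params: the judge model, prompt template, and score regexes must now come from the caller. A hedged sketch of reconstructing the old behavior on top of the new base (the field names are taken from the deleted definition above; how the params object reaches the scorer depends on the scoring API):

params = LLMAsJudgeScoringFnParams(
    judge_model="Llama3.1-8B-Instruct",
    prompt_template=JUDGE_PROMPT,  # e.g. the prompt removed above, or your own
    judge_score_regexes=[r"Total rating: (\d+)"],
)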


@@ -4,20 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from llama_stack.apis.inference.inference import Inference
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
-    BaseScoringFn,
-)
+from .base_scoring_fn import BaseScoringFn
 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
 import re
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
-    aggregate_average,
-)
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
-    llm_as_judge_8b_correctness,
-)
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
+from .fn_defs.llm_as_judge_base import llm_as_judge_base

 class LlmAsJudgeScoringFn(BaseScoringFn):
@@ -29,7 +25,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
         super().__init__(*arg, **kwargs)
         self.inference_api = inference_api
         self.supported_fn_defs_registry = {
-            llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness,
+            llm_as_judge_base.identifier: llm_as_judge_base,
        }

     async def score_row(
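
With the registry now keyed on the base definition, a scoring request selects the function by its identifier and attaches the judge parameters per call. A hypothetical call shape, reusing the params object sketched earlier (the score(...) signature and scoring_impl name are assumptions for illustration, not taken from this diff):

response = await scoring_impl.score(
    input_rows=[
        {
            "input_query": "What is the capital of France?",
            "expected_answer": "Paris",
            "generated_answer": "Paris is the capital of France.",
        }
    ],
    scoring_functions={"meta-reference::llm_as_judge_base": params},
)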


@@ -9,7 +9,7 @@ from .base_scoring_fn import BaseScoringFn
 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
-from .common import aggregate_accuracy
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
 from .fn_defs.regex_parser_multiple_choice_answer import (
     regex_parser_multiple_choice_answer,


@@ -4,19 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
-    BaseScoringFn,
-)
+from .base_scoring_fn import BaseScoringFn
 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
-    aggregate_accuracy,
-)
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
-    subset_of,
-)
+from .fn_defs.subset_of import subset_of

 class SubsetOfScoringFn(BaseScoringFn):
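
All of the scorers touched in this commit follow the same shape: register one or more fn defs, implement score_row for a single example, and hand the per-row results to a shared aggregation helper. A schematic subclass illustrating that pattern (the exact BaseScoringFn method signatures are assumptions; the aggregation import is the new shared path from this commit):

from typing import Any, Dict, List

from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy

class ExampleEqualityFn:  # stands in for a BaseScoringFn subclass
    async def score_row(self, input_row: Dict[str, Any]) -> Dict[str, Any]:
        # One input row in, one ScoringResultRow-shaped dict out.
        expected = input_row["expected_answer"]
        generated = input_row["generated_answer"]
        return {"score": 1.0 if expected == generated else 0.0}

    async def aggregate(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Delegate to the shared helper that replaced the per-provider common.py.
        return aggregate_accuracy(results)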