[Evals API][3/n] scoring_functions / scoring meta-reference implementations (#296)

* wip

* dataset validation

* test_scoring

* cleanup

* clean up test

* comments

* error checking

* dataset client

* test client:

* datasetio client

* clean up

* basic scoring function works

* scorer wip

* equality scorer

* score batch impl

* score batch

* update scoring test

* refactor

* validate scorer input

* address comments

* add all rows scores to ScoringResult

* bugfix

* scoring function def rename
This commit is contained in:
Xi Yan 2024-10-24 14:52:30 -07:00 committed by GitHub
parent e70420a06e
commit cb84034567
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 904 additions and 51 deletions

View file

@ -3,17 +3,20 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
from typing import List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass
from urllib.parse import unquote
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import MetaReferenceDatasetIOConfig
@ -52,11 +55,20 @@ class PandasDataframeDataset(BaseDataset):
return len(self.df)
def __getitem__(self, idx):
assert self.df is not None, "Dataset not loaded. Please call .load() first"
if isinstance(idx, slice):
return self.df.iloc[idx].to_dict(orient="records")
else:
return self.df.iloc[idx].to_dict()
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
# note that we will drop any columns in dataset that are not in the schema
df = df[self.dataset_def.dataset_schema.keys()]
# check all columns in dataset schema are present
assert len(df.columns) == len(self.dataset_def.dataset_schema)
# TODO: type checking against column types in dataset schema
return df
def load(self) -> None:
if self.df is not None:
return
@ -87,7 +99,7 @@ class PandasDataframeDataset(BaseDataset):
else:
raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
self.df = df
self.df = self._validate_dataset_schema(df)
class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@ -123,7 +135,10 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load()
if page_token is None:
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
if page_token is None or len(page_token) == 0:
next_page_token = 0
else:
next_page_token = int(page_token)

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceScoringConfig
async def get_provider_impl(
config: MetaReferenceScoringConfig,
deps: Dict[Api, ProviderSpec],
):
from .scoring import MetaReferenceScoringImpl
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
await impl.initialize()
return impl

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring import * # noqa: F401, F403
class MetaReferenceScoringConfig(BaseModel): ...

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
class BaseScorer(ABC):
"""
Base interface class for all meta-reference scorers.
Each scorer needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scorer_results)
"""
scoring_function_def: ScoringFunctionDef
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __str__(self) -> str:
return self.__class__.__name__
@abstractmethod
def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
raise NotImplementedError()
def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
return [self.score_row(input_row) for input_row in input_rows]

View file

@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
BaseScorer,
)
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
class EqualityScorer(BaseScorer):
"""
A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
"""
scoring_function_def = ScoringFunctionDef(
identifier="equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
)
def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert (
"generated_answer" in input_row
), "Generated answer not found in input row."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
score = 1.0 if expected_answer == generated_answer else 0.0
return {
"score": score,
}
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
assert len(scoring_results) > 0, "Empty scoring results provided."
num_correct = sum(result["score"] for result in scoring_results)
avg_score = num_correct / len(scoring_results)
return {
"accuracy": avg_score,
"num_correct": num_correct,
"num_total": len(scoring_results),
}

View file

@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
EqualityScorer,
)
from .config import MetaReferenceScoringConfig
SUPPORTED_SCORERS = [
EqualityScorer,
]
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(
self,
config: MetaReferenceScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
return [x.scoring_function_def for x in SUPPORTED_SCORERS]
async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None:
raise NotImplementedError(
"Dynamically registering scoring functions is not supported"
)
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.score(
input_rows=all_rows.rows, scoring_functions=scoring_functions
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res.results,
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in SCORER_REGISTRY:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scorer = SCORER_REGISTRY[scoring_fn_id]()
score_results = scorer.score(input_rows)
agg_results = scorer.aggregate(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
)
return ScoreResponse(
results=res,
)