Folder restructure for evals/datasets/scoring (#419)

* rename evals related stuff

* fix datasetio

* fix scoring test

* localfs -> LocalFS

* refactor scoring

* refactor scoring

* remove 8b_correctness scoring_fn from tests

* tests w/ eval params

* scoring fn braintrust fixture

* import
This commit is contained in:
Xi Yan 2024-11-11 17:35:40 -05:00 committed by GitHub
parent 2b7d70ba86
commit b4416b72fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 141 additions and 100 deletions

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import BraintrustScoringConfig
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec],
):
from .braintrust import BraintrustScoringImpl
impl = BraintrustScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
await impl.initialize()
return impl

View file

@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(
self,
config: BraintrustScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.braintrust_evaluators = {
"braintrust::factuality": Factuality(),
"braintrust::answer-correctness": AnswerCorrectness(),
}
self.supported_fn_defs_registry = {
factuality_fn_def.identifier: factuality_fn_def,
answer_correctness_fn_def.identifier: answer_correctness_fn_def,
}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"braintrust"
), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
raise NotImplementedError(
"Registering scoring function not allowed for braintrust provider"
)
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.score(
input_rows=all_rows.rows, scoring_functions=scoring_functions
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res.results,
)
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
input_query = input_row["input_query"]
evaluator = self.braintrust_evaluators[scoring_fn_identifier]
result = evaluator(generated_answer, expected_answer, input=input_query)
score = result.score
return {"score": score, "metadata": result.metadata}
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in self.supported_fn_defs_registry:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
score_results = [
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
agg_results = aggregate_average(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
)
return ScoreResponse(
results=res,
)

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring import * # noqa: F401, F403
class BraintrustScoringConfig(BaseModel): ...

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
answer_correctness_fn_def = ScoringFnDef(
identifier="braintrust::answer-correctness",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
parameters=[],
return_type=NumberType(),
)

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
factuality_fn_def = ScoringFnDef(
identifier="braintrust::factuality",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
parameters=[],
return_type=NumberType(),
)

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceScoringConfig
async def get_provider_impl(
config: MetaReferenceScoringConfig,
deps: Dict[Api, ProviderSpec],
):
from .scoring import MetaReferenceScoringImpl
impl = MetaReferenceScoringImpl(
config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
)
await impl.initialize()
return impl

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring import * # noqa: F401, F403
class MetaReferenceScoringConfig(BaseModel): ...

View file

@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from .config import MetaReferenceScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(
self,
config: MetaReferenceScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
inference_api: Inference,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.inference_api = inference_api
self.scoring_fn_id_impls = {}
async def initialize(self) -> None:
for x in FIXED_FNS:
impl = x()
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
for x in LLM_JUDGE_FNS:
impl = x(inference_api=self.inference_api)
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
self.llm_as_judge_fn = impl
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
scoring_fn_defs_list = [
fn_def
for impl in self.scoring_fn_id_impls.values()
for fn_def in impl.get_supported_scoring_fn_defs()
]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"meta-reference"
), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! "
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.score(
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res.results,
)
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions.keys():
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
)
return ScoreResponse(
results=res,
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
class BaseScoringFn(ABC):
"""
Base interface class for all meta-reference scoring_fns.
Each scoring_fn needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scoring_fn_results)
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {}
def __str__(self) -> str:
return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
return [x for x in self.supported_fn_defs_registry.values()]
def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
if scoring_fn_def.identifier in self.supported_fn_defs_registry:
raise ValueError(
f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
)
self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
@abstractmethod
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
raise NotImplementedError()
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> List[ScoringResultRow]:
return [
await self.score_row(input_row, scoring_fn_identifier, scoring_params)
for input_row in input_rows
]

View file

@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.equality import equality
class EqualityScoringFn(BaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
equality.identifier: equality,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "equality",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert (
"generated_answer" in input_row
), "Generated answer not found in input row."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
score = 1.0 if expected_answer == generated_answer else 0.0
return {
"score": score,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
equality = ScoringFnDef(
identifier="meta-reference::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
return_type=NumberType(),
)

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
llm_as_judge_base = ScoringFnDef(
identifier="meta-reference::llm_as_judge_base",
description="Llm As Judge Scoring Function",
return_type=NumberType(),
)

View file

@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType
MULTILINGUAL_ANSWER_REGEXES = [
r"Answer\s*:",
r"Answer\s*:", # Korean invisible character
r"উত্তর\s*:",
r"उत्तर\s*:",
r"উত্তরঃ",
r"উত্তর\s*:",
r"Antwort\s*:",
r"답변\s*:",
r"정답\s*:",
r"\s*:",
r"答案\s*",
r"答案\s*:",
r"\s*",
r"\s*:",
r"答复\s*",
r"答曰\s*",
r"الإجابة:",
r"الجواب:",
r"إجابة:",
r"الإجابة النهائية:",
r"الإجابة الصحيحة:",
r"الإجابة الصحيحة هي:",
r"الإجابة هي:",
r"Respuesta\s*:",
r"Risposta\s*:",
r"答え\s*:",
r"答え\s*",
r"回答\s*:",
r"回答\s*",
r"解答\s*:",
r"Jawaban\s*:",
r"Réponse\s*:",
r"Resposta\s*:",
r"Jibu\s*:",
r"Idahun\s*:",
r"Ìdáhùn\s*:",
r"Idáhùn\s*:",
r"Àmọ̀nà\s*:",
r"Àdáhùn\s*:",
r"Ànúgọ\s*:",
r"Àṣàyàn\s*:",
]
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
)
regex_parser_multiple_choice_answer = ScoringFnDef(
identifier="meta-reference::regex_parser_multiple_choice_answer",
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
return_type=NumberType(),
params=RegexParserScoringFnParams(
parsing_regexes=[
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES
],
),
)

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
subset_of = ScoringFnDef(
identifier="meta-reference::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
)

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference.inference import Inference
from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
import re
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
from .fn_defs.llm_as_judge_base import llm_as_judge_base
class LlmAsJudgeScoringFn(BaseScoringFn):
"""
A scoring_fn that assigns
"""
def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
super().__init__(*arg, **kwargs)
self.inference_api = inference_api
self.supported_fn_defs_registry = {
llm_as_judge_base.identifier: llm_as_judge_base,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert (
scoring_fn_identifier is not None
), "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
# override params if scoring_params is provided
if scoring_params is not None:
fn_def.params = scoring_params
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
assert (
fn_def.params.prompt_template is not None
), "LLM Judge prompt_template not found."
assert (
fn_def.params.judge_score_regexes is not None
), "LLM Judge judge_score_regexes not found."
input_query = input_row["input_query"]
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
judge_input_msg = fn_def.params.prompt_template.format(
input_query=input_query,
expected_answer=expected_answer,
generated_answer=generated_answer,
)
judge_response = await self.inference_api.chat_completion(
model=fn_def.params.judge_model,
messages=[
{
"role": "user",
"content": judge_input_msg,
}
],
)
content = judge_response.completion_message.content
rating_regexes = fn_def.params.judge_score_regexes
judge_rating = None
for regex in rating_regexes:
match = re.search(regex, content)
if match:
judge_rating = int(match.group(1))
break
return {
"score": judge_rating,
"judge_feedback": content,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_average(scoring_results)

View file

@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
)
class RegexParserScoringFn(BaseScoringFn):
"""
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert (
scoring_fn_identifier is not None
), "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
assert (
fn_def.params is not None
and fn_def.params.type == ScoringConfigType.regex_parser.value
), f"RegexParserScoringFnParams not found for {fn_def}."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
# parse answer according to regex
parsed_answer = None
for regex in fn_def.params.parsing_regexes:
match = re.search(regex, generated_answer)
if match:
parsed_answer = match.group(1)
break
score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
return {
"score": score,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.subset_of import subset_of
class SubsetOfScoringFn(BaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
subset_of.identifier: subset_of,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "subset_of",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
score = 1.0 if expected_answer in generated_answer else 0.0
return {
"score": score,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)