temp commit

Botao Chen 2025-03-07 16:58:07 -08:00
parent b99da3b9e4
commit 6555e2136f
3 changed files with 38 additions and 55 deletions


@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    RegexParserScoringFnParams,
+    ScoringFn,
+)
+
+MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
+
+
+regex_parser_math_response = ScoringFn(
+    identifier="basic::regex_parser_math_response",
+    description="For math-related benchmarks, extract the final answer from the generated response and check whether it matches the expected answer.",
+    return_type=NumberType(),
+    provider_id="basic",
+    provider_resource_id="regex-parser-math-response",
+    params=RegexParserScoringFnParams(
+        parsing_regexes=MATH_ANSWER_REGEXES,
+        aggregation_functions=[AggregationFunctionType.accuracy],
+    ),
+)
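
As a quick sanity check (not part of the commit), the single pattern in MATH_ANSWER_REGEXES pulls the boxed answer out of a model response through its named group X; the sample response string below is assumed for illustration:

import re

pattern = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
response = "So the final answer is: $\\boxed{42}$"
match = re.search(pattern, response)
print(match.group("X") if match else None)  # prints: 42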


@@ -1,44 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any, Dict, Optional
-
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
-
-from .fn_defs.equality import equality
-
-
-# class EqualityScoringFn(RegisteredBaseScoringFn):
-class MathExactMatchFn(RegisteredBaseScoringFn):
-    """
-    A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
-    """
-
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.supported_fn_defs_registry = {
-            equality.identifier: equality,
-        }
-
-    async def score_row(
-        self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "equality",
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> ScoringResultRow:
-        assert "expected_answer" in input_row, "Expected answer not found in input row."
-        assert "generated_answer" in input_row, "Generated answer not found in input row."
-
-        expected_answer = input_row["expected_answer"]
-        generated_answer = input_row["generated_answer"]
-        score = 1.0 if expected_answer == generated_answer else 0.0
-        return {
-            "score": score,
-        }
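
Worth noting why this file is going away: it graded by plain string equality, which is brittle for math benchmarks, since two renderings of the same value score 0.0. A minimal illustration (example values assumed):

expected, generated = "0.5", "1/2"
print(1.0 if expected == generated else 0.0)  # prints: 0.0, though the values are equal

The regex-plus-normalization path in the remaining scorer is meant to close exactly this gap.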


@@ -13,7 +13,7 @@ from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
 from .fn_defs.regex_parser_multiple_choice_answer import (
     regex_parser_multiple_choice_answer,
 )
-from ...utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex
+from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex


 class RegexParserScoringFn(RegisteredBaseScoringFn):
@@ -45,25 +45,25 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
-        pattern = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
+        parsing_regexes = fn_def.params.parsing_regexes
+        assert len(parsing_regexes) == 1, "Only one parsing regex is supported for regex_parser_math_response scoring function."
+        parsing_regexes = fn_def.params.parsing_regexes[0]
+        # parsing_regexes = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"

         normalized_generated_answer = normalize_final_answer(
             first_answer(generated_answer),
-            pattern,
+            parsing_regexes,
             match_first=True,
         )
         normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))

-        # parse answer according to regex
-        parsed_answer = None
-        for regex in fn_def.params.parsing_regexes:
-            match = re.search(regex, generated_answer)
-            if match:
-                parsed_answer = match.group(1)
-                break
+        normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
+        normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))

-        score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
+        score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
         return {
             "score": score,
         }
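
End to end, the new scoring path normalizes both answers before comparing them. Below is a rough, self-contained sketch of that flow with simplified, hypothetical stand-ins for the math_utils helpers; the real first_answer, normalize_final_answer, and try_evaluate_frac cover far more LaTeX, and try_evaluate_latex is omitted entirely for brevity:

import re

# Hypothetical, simplified stand-ins for llama_stack's math_utils helpers.
def first_answer(text: str) -> str:
    # Keep everything from the last "final answer" marker onward.
    idx = text.lower().rfind("final answer")
    return text[idx:] if idx != -1 else text

def normalize_final_answer(answer: str, regex: str) -> str:
    # Extract the named group X when the regex defines one and matches.
    match = re.search(regex, answer)
    if match and match.groupdict().get("X") is not None:
        answer = match.group("X")
    return answer.strip().strip("$")

def try_evaluate_frac(answer: str) -> str:
    # Collapse a simple a/b fraction to a decimal so "1/2" equals "0.5".
    m = re.fullmatch(r"(-?\d+)/(-?\d+)", answer)
    return str(float(m.group(1)) / float(m.group(2))) if m else answer

pattern = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
generated = "Therefore the final answer is: $\\boxed{1/2}$"
expected = "0.5"

norm_gen = try_evaluate_frac(normalize_final_answer(first_answer(generated), pattern))
norm_exp = try_evaluate_frac(normalize_final_answer(expected, r".*"))
print(1.0 if norm_gen == norm_exp else 0.0)  # prints: 1.0

One consequence visible in the sketch: the expected answer is normalized with the catch-all pattern r".*", so only the generated answer goes through the benchmark-specific regex.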