diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py
new file mode 100644
index 000000000..7d5f49c8a
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    RegexParserScoringFnParams,
+    ScoringFn,
+)
+
+MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
+
+
+regex_parser_math_response = ScoringFn(
+    identifier="basic::regex_parser_math_response",
+    description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
+    return_type=NumberType(),
+    provider_id="basic",
+    provider_resource_id="regex-parser-math-response",
+    params=RegexParserScoringFnParams(
+        parsing_regexes=MATH_ANSWER_REGEXES,
+        aggregation_functions=[AggregationFunctionType.accuracy],
+    ),
+)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/math_exact_match_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/math_exact_match_fn.py
deleted file mode 100644
index 30caecf23..000000000
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/math_exact_match_fn.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any, Dict, Optional
-
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
-
-from .fn_defs.equality import equality
-
-
-# class EqualityScoringFn(RegisteredBaseScoringFn):
-class MathExactMatchFn(RegisteredBaseScoringFn):
-    """
-    A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
-    """
-
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.supported_fn_defs_registry = {
-            equality.identifier: equality,
-        }
-
-    async def score_row(
-        self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "equality",
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> ScoringResultRow:
-        assert "expected_answer" in input_row, "Expected answer not found in input row."
-        assert "generated_answer" in input_row, "Generated answer not found in input row."
-
-        expected_answer = input_row["expected_answer"]
-        generated_answer = input_row["generated_answer"]
-
-
-        score = 1.0 if expected_answer == generated_answer else 0.0
-        return {
-            "score": score,
-        }
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/math_exact_match.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py
similarity index 73%
rename from llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/math_exact_match.py
rename to llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py
index a743d606a..35789fb2c 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/math_exact_match.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py
@@ -13,7 +13,7 @@ from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseSc
 from .fn_defs.regex_parser_multiple_choice_answer import (
     regex_parser_multiple_choice_answer,
 )
-from ...utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex
+from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex
 
 
 class RegexParserScoringFn(RegisteredBaseScoringFn):
@@ -45,25 +45,25 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
 
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
+        parsing_regexes = fn_def.params.parsing_regexes
 
-        pattern = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
+        assert len(parsing_regexes) == 1, "Only one parsing regex is supported for regex_parser_math_response scoring function."
+
+        parsing_regexes = fn_def.params.parsing_regexes[0]
+        # parsing_regexes = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
+
         normalized_generated_answer = normalize_final_answer(
             first_answer(generated_answer),
-            pattern,
+            parsing_regexes,
             match_first=True,
         )
         normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
 
-        # parse answer according to regex
-        parsed_answer = None
-        for regex in fn_def.params.parsing_regexes:
-            match = re.search(regex, generated_answer)
-            if match:
-                parsed_answer = match.group(1)
-                break
+        normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
+        normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
 
-        score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
+        score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
         return {
             "score": score,
         }
 
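
Reviewer note: a minimal, stdlib-only sketch of what the single pattern in MATH_ANSWER_REGEXES pulls out of a model response before normalization. The response string below is an invented example, and the group name X is taken from the pattern as written above; the real scoring path additionally runs first_answer, normalize_final_answer, try_evaluate_latex, and try_evaluate_frac from ..utils.math_utils on both the generated and expected answers before the exact-equality comparison.

import re

# Same pattern as MATH_ANSWER_REGEXES[0]; the named group X captures the
# contents of the final \boxed{...} expression in the answer line.
BOXED_PATTERN = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"

# Invented example response, mimicking the MATH benchmark answer format.
response = r"Combining the terms, the final answer is: $\boxed{\frac{1}{2}}$"

match = re.search(BOXED_PATTERN, response)
print(match.group("X") if match else "no boxed answer")  # -> \frac{1}{2}

Because the extracted string is normalized on both sides, the final score reduces to a string equality check on the normalized forms, as in the new score computation above.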