From 90ed992138e88612fccd9561162c4e5830cc2713 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Mon, 10 Mar 2025 00:38:16 -0700 Subject: [PATCH] apply fix --- .../regex_parser_math_response_scoring_fn.py | 54 ++++--------------- .../scoring_fn/regex_parser_scoring_fn.py | 1 - .../inline/scoring/basic/utils/math_utils.py | 20 ++++--- 3 files changed, 22 insertions(+), 53 deletions(-) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py index 6ad3813ca..d6c78a9ac 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py @@ -3,17 +3,16 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import re from typing import Any, Dict, Optional from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn +from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex from .fn_defs.regex_parser_math_response import ( regex_parser_math_response, ) -from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex, extract_result_from_boxed class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn): @@ -33,7 +32,6 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn): scoring_fn_identifier: Optional[str] = None, scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: - print('I reach RegexParserMathResponseScoringFn') assert scoring_fn_identifier is not None, "Scoring function identifier not found." fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] if scoring_params is not None: @@ -44,58 +42,24 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn): ) expected_answer = input_row["expected_answer"] - expected_answer = r""" - We have that $r = \sqrt{0^2 + 3^2} = 3.$ Also, if we draw the line connecting the origin and $(0,3),$ this line makes an angle of $\frac{\pi}{2}$ with the positive $x$-axis. - - [asy] - unitsize(0.8 cm); - - draw((-0.5,0)--(3.5,0)); - draw((0,-0.5)--(0,3.5)); - draw(arc((0,0),3,0,90),red,Arrow(6)); - - dot((0,3), red); - label("$(0,3)$", (0,3), W); - dot((3,0), red); - [/asy] - - Therefore, the polar coordinates are $\boxed{\left( 3, \frac{\pi}{2} \right)}.$ - """ generated_answer = input_row["generated_answer"] - print('expected_answer', expected_answer) - print('generated_answer', generated_answer) - parsing_regexes = fn_def.params.parsing_regexes - - assert len(parsing_regexes) == 1, "Only one parsing regex is supported for regex_parser_math_response scoring function." - + assert len(parsing_regexes) == 1, ( + "Only one parsing regex is supported for regex_parser_math_response scoring function." + ) parsing_regexes = fn_def.params.parsing_regexes[0] - print('parsing_regexes', parsing_regexes) - - normalized_generated_answer = normalize_final_answer( - first_answer(generated_answer), - parsing_regexes, - match_first=True, - ) - print('normalized_generated_answer_1', normalized_generated_answer) - - + first_answer(generated_answer), + parsing_regexes, + match_first=True, + ) normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer)) - print('normalized_generated_answer_2', normalized_generated_answer) - - - # print('extract_result_from_boxed', extract_result_from_boxed(expected_answer)) - # normalized_expected_answer = normalize_final_answer(extract_result_from_boxed(expected_answer), r".*final answer is:?\s*\$\\boxed{(?P.*)}\$") - normalized_expected_answer = normalize_final_answer(expected_answer, r"\$\\boxed{(?P.*)}\$") - print('normalized_expected_answer_1', normalized_expected_answer) + normalized_expected_answer = normalize_final_answer(expected_answer, r".*") normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer)) - print('normalized_expected_answer_2', normalized_expected_answer) - score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0 return { "score": score, diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py index bc105ebfb..0606a9581 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py @@ -32,7 +32,6 @@ class RegexParserScoringFn(RegisteredBaseScoringFn): scoring_fn_identifier: Optional[str] = None, scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: - print("I reach RegexParserScoringFn") assert scoring_fn_identifier is not None, "Scoring function identifier not found." fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] if scoring_params is not None: diff --git a/llama_stack/providers/inline/scoring/basic/utils/math_utils.py b/llama_stack/providers/inline/scoring/basic/utils/math_utils.py index 120f6fdc4..f108886a2 100644 --- a/llama_stack/providers/inline/scoring/basic/utils/math_utils.py +++ b/llama_stack/providers/inline/scoring/basic/utils/math_utils.py @@ -6,14 +6,15 @@ import contextlib import re -from types import FrameType -from typing import Sequence, Iterator, Optional -import sympy import signal +from types import FrameType +from typing import Iterator, Optional, Sequence + class TimeoutException(Exception): pass + @contextlib.contextmanager def time_limit(seconds: float) -> Iterator[None]: def signal_handler(signum: int, frame: Optional[FrameType]) -> None: @@ -26,6 +27,7 @@ def time_limit(seconds: float) -> Iterator[None]: finally: signal.setitimer(signal.ITIMER_REAL, 0) + # from minerva SUBSTITUTIONS = [ ("an ", ""), @@ -103,6 +105,7 @@ def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str: continue return new_expression + def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str: try: with time_limit(seconds=5): @@ -114,12 +117,12 @@ def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str: return expression - def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str: for marker in markers: text = text.split(marker)[0] return text + def extract_result_from_boxed(answer: str) -> str: box_start = "\\boxed" # format is `\\boxed $` or `\\boxed{}`, with potential white spaces framing `` @@ -155,9 +158,7 @@ def extract_result_from_boxed(answer: str) -> str: # from minerva paper + _normalise_result from xavierm -def normalize_final_answer( - final_answer: str, regex_pattern: str, match_first: bool = True -) -> str: +def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str: """Extract and normalize a final answer to a quantitative reasoning question.""" match = re.findall(regex_pattern, final_answer) extraction: str @@ -203,6 +204,7 @@ def normalize_final_answer( final_answer = final_answer[1] return _normalise_result(final_answer) + def _normalise_result(string: str) -> str: # linebreaks string = string.replace("\n", "") @@ -220,6 +222,7 @@ def _normalise_result(string: str) -> str: # remove \left and \right string = string.replace("\\left", "") + string = string.replace("\\le", "") string = string.replace("\\right", "") # Remove circ (degrees) @@ -266,6 +269,7 @@ def _normalise_result(string: str) -> str: return string + def _remove_right_units(string: str) -> str: # "\\text{ " only ever occurs (at least in the val set) when describing units try: @@ -278,6 +282,7 @@ def _remove_right_units(string: str) -> str: except AssertionError: return string + def _fix_sqrt(string: str) -> str: if "\\sqrt" not in string: return string @@ -328,6 +333,7 @@ def _fix_fracs(string: str) -> str: string = new_str return string + def _fix_a_slash_b(string: str) -> str: if len(string.split("/")) != 2: return string