apply fix

This commit is contained in:
Botao Chen 2025-03-10 00:38:16 -07:00
parent 10a9f6a5c8
commit 90ed992138
3 changed files with 22 additions and 53 deletions

View file

@ -3,17 +3,16 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import re
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
from .fn_defs.regex_parser_math_response import ( from .fn_defs.regex_parser_math_response import (
regex_parser_math_response, regex_parser_math_response,
) )
from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex, extract_result_from_boxed
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn): class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
@ -33,7 +32,6 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
scoring_fn_identifier: Optional[str] = None, scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None, scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow: ) -> ScoringResultRow:
print('I reach RegexParserMathResponseScoringFn')
assert scoring_fn_identifier is not None, "Scoring function identifier not found." assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None: if scoring_params is not None:
@ -44,58 +42,24 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
) )
expected_answer = input_row["expected_answer"] expected_answer = input_row["expected_answer"]
expected_answer = r"""
We have that $r = \sqrt{0^2 + 3^2} = 3.$ Also, if we draw the line connecting the origin and $(0,3),$ this line makes an angle of $\frac{\pi}{2}$ with the positive $x$-axis.
[asy]
unitsize(0.8 cm);
draw((-0.5,0)--(3.5,0));
draw((0,-0.5)--(0,3.5));
draw(arc((0,0),3,0,90),red,Arrow(6));
dot((0,3), red);
label("$(0,3)$", (0,3), W);
dot((3,0), red);
[/asy]
Therefore, the polar coordinates are $\boxed{\left( 3, \frac{\pi}{2} \right)}.$
"""
generated_answer = input_row["generated_answer"] generated_answer = input_row["generated_answer"]
print('expected_answer', expected_answer)
print('generated_answer', generated_answer)
parsing_regexes = fn_def.params.parsing_regexes parsing_regexes = fn_def.params.parsing_regexes
assert len(parsing_regexes) == 1, (
assert len(parsing_regexes) == 1, "Only one parsing regex is supported for regex_parser_math_response scoring function." "Only one parsing regex is supported for regex_parser_math_response scoring function."
)
parsing_regexes = fn_def.params.parsing_regexes[0] parsing_regexes = fn_def.params.parsing_regexes[0]
print('parsing_regexes', parsing_regexes)
normalized_generated_answer = normalize_final_answer( normalized_generated_answer = normalize_final_answer(
first_answer(generated_answer), first_answer(generated_answer),
parsing_regexes, parsing_regexes,
match_first=True, match_first=True,
) )
print('normalized_generated_answer_1', normalized_generated_answer)
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer)) normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
print('normalized_generated_answer_2', normalized_generated_answer) normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
# print('extract_result_from_boxed', extract_result_from_boxed(expected_answer))
# normalized_expected_answer = normalize_final_answer(extract_result_from_boxed(expected_answer), r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$")
normalized_expected_answer = normalize_final_answer(expected_answer, r"\$\\boxed{(?P<X>.*)}\$")
print('normalized_expected_answer_1', normalized_expected_answer)
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer)) normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
print('normalized_expected_answer_2', normalized_expected_answer)
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0 score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
return { return {
"score": score, "score": score,

View file

@ -32,7 +32,6 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
scoring_fn_identifier: Optional[str] = None, scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None, scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow: ) -> ScoringResultRow:
print("I reach RegexParserScoringFn")
assert scoring_fn_identifier is not None, "Scoring function identifier not found." assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None: if scoring_params is not None:

View file

@ -6,14 +6,15 @@
import contextlib import contextlib
import re import re
from types import FrameType
from typing import Sequence, Iterator, Optional
import sympy
import signal import signal
from types import FrameType
from typing import Iterator, Optional, Sequence
class TimeoutException(Exception): class TimeoutException(Exception):
pass pass
@contextlib.contextmanager @contextlib.contextmanager
def time_limit(seconds: float) -> Iterator[None]: def time_limit(seconds: float) -> Iterator[None]:
def signal_handler(signum: int, frame: Optional[FrameType]) -> None: def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
@ -26,6 +27,7 @@ def time_limit(seconds: float) -> Iterator[None]:
finally: finally:
signal.setitimer(signal.ITIMER_REAL, 0) signal.setitimer(signal.ITIMER_REAL, 0)
# from minerva # from minerva
SUBSTITUTIONS = [ SUBSTITUTIONS = [
("an ", ""), ("an ", ""),
@ -103,6 +105,7 @@ def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
continue continue
return new_expression return new_expression
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str: def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
try: try:
with time_limit(seconds=5): with time_limit(seconds=5):
@ -114,12 +117,12 @@ def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
return expression return expression
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str: def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
for marker in markers: for marker in markers:
text = text.split(marker)[0] text = text.split(marker)[0]
return text return text
def extract_result_from_boxed(answer: str) -> str: def extract_result_from_boxed(answer: str) -> str:
box_start = "\\boxed" box_start = "\\boxed"
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>` # format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
@ -155,9 +158,7 @@ def extract_result_from_boxed(answer: str) -> str:
# from minerva paper + _normalise_result from xavierm # from minerva paper + _normalise_result from xavierm
def normalize_final_answer( def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
final_answer: str, regex_pattern: str, match_first: bool = True
) -> str:
"""Extract and normalize a final answer to a quantitative reasoning question.""" """Extract and normalize a final answer to a quantitative reasoning question."""
match = re.findall(regex_pattern, final_answer) match = re.findall(regex_pattern, final_answer)
extraction: str extraction: str
@ -203,6 +204,7 @@ def normalize_final_answer(
final_answer = final_answer[1] final_answer = final_answer[1]
return _normalise_result(final_answer) return _normalise_result(final_answer)
def _normalise_result(string: str) -> str: def _normalise_result(string: str) -> str:
# linebreaks # linebreaks
string = string.replace("\n", "") string = string.replace("\n", "")
@ -220,6 +222,7 @@ def _normalise_result(string: str) -> str:
# remove \left and \right # remove \left and \right
string = string.replace("\\left", "") string = string.replace("\\left", "")
string = string.replace("\\le", "")
string = string.replace("\\right", "") string = string.replace("\\right", "")
# Remove circ (degrees) # Remove circ (degrees)
@ -266,6 +269,7 @@ def _normalise_result(string: str) -> str:
return string return string
def _remove_right_units(string: str) -> str: def _remove_right_units(string: str) -> str:
# "\\text{ " only ever occurs (at least in the val set) when describing units # "\\text{ " only ever occurs (at least in the val set) when describing units
try: try:
@ -278,6 +282,7 @@ def _remove_right_units(string: str) -> str:
except AssertionError: except AssertionError:
return string return string
def _fix_sqrt(string: str) -> str: def _fix_sqrt(string: str) -> str:
if "\\sqrt" not in string: if "\\sqrt" not in string:
return string return string
@ -328,6 +333,7 @@ def _fix_fracs(string: str) -> str:
string = new_str string = new_str
return string return string
def _fix_a_slash_b(string: str) -> str: def _fix_a_slash_b(string: str) -> str:
if len(string.split("/")) != 2: if len(string.split("/")) != 2:
return string return string