mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-11 20:40:40 +00:00
apply fix
This commit is contained in:
parent
10a9f6a5c8
commit
90ed992138
3 changed files with 22 additions and 53 deletions
|
@ -3,17 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
|
||||
from .fn_defs.regex_parser_math_response import (
|
||||
regex_parser_math_response,
|
||||
)
|
||||
from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex, extract_result_from_boxed
|
||||
|
||||
|
||||
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
||||
|
@ -33,7 +32,6 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
|||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
print('I reach RegexParserMathResponseScoringFn')
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
|
@ -44,58 +42,24 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
|||
)
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
expected_answer = r"""
|
||||
We have that $r = \sqrt{0^2 + 3^2} = 3.$ Also, if we draw the line connecting the origin and $(0,3),$ this line makes an angle of $\frac{\pi}{2}$ with the positive $x$-axis.
|
||||
|
||||
[asy]
|
||||
unitsize(0.8 cm);
|
||||
|
||||
draw((-0.5,0)--(3.5,0));
|
||||
draw((0,-0.5)--(0,3.5));
|
||||
draw(arc((0,0),3,0,90),red,Arrow(6));
|
||||
|
||||
dot((0,3), red);
|
||||
label("$(0,3)$", (0,3), W);
|
||||
dot((3,0), red);
|
||||
[/asy]
|
||||
|
||||
Therefore, the polar coordinates are $\boxed{\left( 3, \frac{\pi}{2} \right)}.$
|
||||
"""
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
print('expected_answer', expected_answer)
|
||||
print('generated_answer', generated_answer)
|
||||
|
||||
parsing_regexes = fn_def.params.parsing_regexes
|
||||
|
||||
assert len(parsing_regexes) == 1, "Only one parsing regex is supported for regex_parser_math_response scoring function."
|
||||
|
||||
assert len(parsing_regexes) == 1, (
|
||||
"Only one parsing regex is supported for regex_parser_math_response scoring function."
|
||||
)
|
||||
parsing_regexes = fn_def.params.parsing_regexes[0]
|
||||
|
||||
print('parsing_regexes', parsing_regexes)
|
||||
|
||||
|
||||
normalized_generated_answer = normalize_final_answer(
|
||||
first_answer(generated_answer),
|
||||
parsing_regexes,
|
||||
match_first=True,
|
||||
)
|
||||
print('normalized_generated_answer_1', normalized_generated_answer)
|
||||
|
||||
|
||||
first_answer(generated_answer),
|
||||
parsing_regexes,
|
||||
match_first=True,
|
||||
)
|
||||
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
|
||||
|
||||
print('normalized_generated_answer_2', normalized_generated_answer)
|
||||
|
||||
|
||||
# print('extract_result_from_boxed', extract_result_from_boxed(expected_answer))
|
||||
# normalized_expected_answer = normalize_final_answer(extract_result_from_boxed(expected_answer), r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$")
|
||||
normalized_expected_answer = normalize_final_answer(expected_answer, r"\$\\boxed{(?P<X>.*)}\$")
|
||||
print('normalized_expected_answer_1', normalized_expected_answer)
|
||||
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
|
||||
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
|
||||
|
||||
print('normalized_expected_answer_2', normalized_expected_answer)
|
||||
|
||||
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
|
|
|
@ -32,7 +32,6 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
|
|||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
print("I reach RegexParserScoringFn")
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
|
|
|
@ -6,14 +6,15 @@
|
|||
|
||||
import contextlib
|
||||
import re
|
||||
from types import FrameType
|
||||
from typing import Sequence, Iterator, Optional
|
||||
import sympy
|
||||
import signal
|
||||
from types import FrameType
|
||||
from typing import Iterator, Optional, Sequence
|
||||
|
||||
|
||||
class TimeoutException(Exception):
    """Signal that a time-limited operation exceeded its allowed duration.

    This is a module-specific exception deriving directly from ``Exception``;
    it is unrelated to the built-in ``TimeoutError``.
    NOTE(review): presumably raised from the signal handler installed by
    ``time_limit`` -- the handler body is not visible here, confirm.
    """
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def time_limit(seconds: float) -> Iterator[None]:
|
||||
def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
|
||||
|
@ -26,6 +27,7 @@ def time_limit(seconds: float) -> Iterator[None]:
|
|||
finally:
|
||||
signal.setitimer(signal.ITIMER_REAL, 0)
|
||||
|
||||
|
||||
# from minerva
|
||||
SUBSTITUTIONS = [
|
||||
("an ", ""),
|
||||
|
@ -103,6 +105,7 @@ def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
|
|||
continue
|
||||
return new_expression
|
||||
|
||||
|
||||
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
|
||||
try:
|
||||
with time_limit(seconds=5):
|
||||
|
@ -114,12 +117,12 @@ def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
|
|||
return expression
|
||||
|
||||
|
||||
|
||||
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
    """Truncate ``text`` at the first occurrence of each marker, in order.

    Each marker is applied to the result of the previous truncation, so the
    returned string contains none of the (non-empty) markers.

    Args:
        text: The raw text to truncate.
        markers: Substrings at which the text is cut. Empty strings are
            ignored.

    Returns:
        The leading portion of ``text`` that precedes the markers, or ``text``
        unchanged when no marker occurs in it.
    """
    for marker in markers:
        if not marker:
            # An empty separator makes str.split / str.partition raise
            # ValueError; it cannot delimit anything, so skip it.
            continue
        # partition stops at the first hit instead of splitting the whole
        # string just to keep element 0.
        text = text.partition(marker)[0]
    return text
|
||||
|
||||
|
||||
def extract_result_from_boxed(answer: str) -> str:
|
||||
box_start = "\\boxed"
|
||||
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
|
||||
|
@ -155,9 +158,7 @@ def extract_result_from_boxed(answer: str) -> str:
|
|||
|
||||
|
||||
# from minerva paper + _normalise_result from xavierm
|
||||
def normalize_final_answer(
|
||||
final_answer: str, regex_pattern: str, match_first: bool = True
|
||||
) -> str:
|
||||
def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
|
||||
"""Extract and normalize a final answer to a quantitative reasoning question."""
|
||||
match = re.findall(regex_pattern, final_answer)
|
||||
extraction: str
|
||||
|
@ -203,6 +204,7 @@ def normalize_final_answer(
|
|||
final_answer = final_answer[1]
|
||||
return _normalise_result(final_answer)
|
||||
|
||||
|
||||
def _normalise_result(string: str) -> str:
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
@ -220,6 +222,7 @@ def _normalise_result(string: str) -> str:
|
|||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\le", "")
|
||||
string = string.replace("\\right", "")
|
||||
|
||||
# Remove circ (degrees)
|
||||
|
@ -266,6 +269,7 @@ def _normalise_result(string: str) -> str:
|
|||
|
||||
return string
|
||||
|
||||
|
||||
def _remove_right_units(string: str) -> str:
|
||||
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
||||
try:
|
||||
|
@ -278,6 +282,7 @@ def _remove_right_units(string: str) -> str:
|
|||
except AssertionError:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string: str) -> str:
|
||||
if "\\sqrt" not in string:
|
||||
return string
|
||||
|
@ -328,6 +333,7 @@ def _fix_fracs(string: str) -> str:
|
|||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string: str) -> str:
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue