From e3edca77391ebc73af7fad0ea4b5d4132961c067 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Mon, 10 Mar 2025 20:38:28 -0700 Subject: [PATCH] feat: [new open benchmark] Math 500 (#1538) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What does this PR do? Created a new math_500 open-benchmark based on OpenAI's [Let's Verify Step by Step](https://arxiv.org/abs/2305.20050) paper and hugging face's [HuggingFaceH4/MATH-500](https://huggingface.co/datasets/HuggingFaceH4/MATH-500) dataset. The challenge part of this benchmark is to parse the generated and expected answer and verify if they are same. For the parsing part, we refer to [Minerva: Solving Quantitative Reasoning Problems with Language Models](https://research.google/blog/minerva-solving-quantitative-reasoning-problems-with-language-models/). To simply the parse logic, as the next step, we plan to also refer to what [simple-eval](https://github.com/openai/simple-evals) is doing, using llm as judge to check if the generated answer matches the expected answer or not ## Test Plan on sever side, spin up a server with open-benchmark template `llama stack run llama_stack/templates/open-benchamrk/run.yaml` on client side, issue an open benchmark eval request `llama-stack-client --endpoint xxx eval run-benchmark "meta-reference-math-500" --model-id "meta-llama/Llama-3.3-70B-Instruct" --output-dir "/home/markchen1015/" --num-examples 20` and get ther aggregated eval results Screenshot 2025-03-10 at 7 57 04 PM check the generated answer and the related scoring and they make sense --- .../providers/inline/scoring/basic/scoring.py | 3 +- .../fn_defs/regex_parser_math_response.py | 27 ++ .../regex_parser_math_response_scoring_fn.py | 66 ++++ .../inline/scoring/basic/utils/math_utils.py | 330 ++++++++++++++++++ .../utils/scoring/basic_scoring_utils.py | 26 ++ llama_stack/templates/open-benchmark/run.yaml | 20 +- 6 files changed, 470 insertions(+), 2 deletions(-) create mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py create mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py create mode 100644 llama_stack/providers/inline/scoring/basic/utils/math_utils.py create mode 100644 llama_stack/providers/utils/scoring/basic_scoring_utils.py diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 13cd78243..00945b99d 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -23,10 +23,11 @@ from llama_stack.providers.utils.common.data_schema_validator import ( from .config import BasicScoringConfig from .scoring_fn.equality_scoring_fn import EqualityScoringFn +from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn -FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn] +FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn] class BasicScoringImpl( diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py new file mode 100644 index 000000000..8b1bf5352 --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + RegexParserScoringFnParams, + ScoringFn, +) + +MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P.*)}\$"] + + +regex_parser_math_response = ScoringFn( + identifier="basic::regex_parser_math_response", + description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match", + return_type=NumberType(), + provider_id="basic", + provider_resource_id="regex-parser-math-response", + params=RegexParserScoringFnParams( + parsing_regexes=MATH_ANSWER_REGEXES, + aggregation_functions=[AggregationFunctionType.accuracy], + ), +) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py new file mode 100644 index 000000000..d6c78a9ac --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Any, Dict, Optional + +from llama_stack.apis.scoring import ScoringResultRow +from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType +from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn + +from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex +from .fn_defs.regex_parser_math_response import ( + regex_parser_math_response, +) + + +class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn): + """ + A scoring_fn for math benchamrks that parses answer from generated response according to context and check match with expected_answer. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.supported_fn_defs_registry = { + regex_parser_math_response.identifier: regex_parser_math_response, + } + + async def score_row( + self, + input_row: Dict[str, Any], + scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, + ) -> ScoringResultRow: + assert scoring_fn_identifier is not None, "Scoring function identifier not found." + fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] + if scoring_params is not None: + fn_def.params = scoring_params + + assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, ( + f"RegexParserScoringFnParams not found for {fn_def}." + ) + + expected_answer = input_row["expected_answer"] + generated_answer = input_row["generated_answer"] + + parsing_regexes = fn_def.params.parsing_regexes + assert len(parsing_regexes) == 1, ( + "Only one parsing regex is supported for regex_parser_math_response scoring function." + ) + parsing_regexes = fn_def.params.parsing_regexes[0] + + normalized_generated_answer = normalize_final_answer( + first_answer(generated_answer), + parsing_regexes, + match_first=True, + ) + normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer)) + + normalized_expected_answer = normalize_final_answer(expected_answer, r".*") + normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer)) + + score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0 + return { + "score": score, + } diff --git a/llama_stack/providers/inline/scoring/basic/utils/math_utils.py b/llama_stack/providers/inline/scoring/basic/utils/math_utils.py new file mode 100644 index 000000000..e11fc625b --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/utils/math_utils.py @@ -0,0 +1,330 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import re +from typing import Sequence + +from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit + +# from minerva +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] + +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "ft", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str: + if isinstance(expression, float): + return expression + new_expression = f"{expression}" + regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}") + for match in re.finditer(regex, expression): + try: + value = float(match.group(1)) / float(match.group(2)) + new_expression = new_expression.replace( + match.group(), + f"{{value:{fmt}}}".format(value=value), + 1, + ) + except Exception: + continue + return new_expression + + +def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str: + try: + with time_limit(seconds=5): + from sympy.parsing.latex import parse_latex + + value = parse_latex(expression).evalf() # type: ignore + return f"{{value:{fmt}}}".format(value=value) + except Exception: + return expression + + +def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str: + for marker in markers: + text = text.split(marker)[0] + return text + + +def extract_result_from_boxed(answer: str) -> str: + box_start = "\\boxed" + # format is `\\boxed $` or `\\boxed{}`, with potential white spaces framing `` + start = answer.rfind(box_start) + if start < 0: + return "" + answer = answer[start + len(box_start) :].strip() + ends_with_curly = answer.startswith("{") + i = 0 + open_braces = 0 + while i < len(answer): + if answer[i] == "{": + open_braces += 1 + elif answer[i] == "}": + open_braces -= 1 + if open_braces == 0: + if ends_with_curly: + answer = answer[: i + 1].strip() + break + elif answer[i] == "$": + answer = answer[:i].strip() + break + i += 1 + else: + return "" + # remove extra curly braces + while True: + if answer.startswith("{") and answer.endswith("}"): + answer = answer[1:-1].strip() + else: + break + return answer + + +# from minerva paper + _normalise_result from xavierm +def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str: + """Extract and normalize a final answer to a quantitative reasoning question.""" + match = re.findall(regex_pattern, final_answer) + extraction: str + if len(match) > 0: + if match_first: + extraction = match[0] + else: + extraction = match[-1] + else: + extraction = extract_result_from_boxed(final_answer) + + if len(extraction) == 0: + return final_answer + else: + final_answer = extraction + final_answer = final_answer.split("=")[-1] + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + # Extract answer that is in LaTeX math, is bold, + # is surrounded by a box, etc. + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + # Normalize 100,000 -> 100000 + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + # If the final answer is a single letter in parentheses, remove the parentheses + # Example: (a) -> a (but not (ab) -> ab) + if re.match(r"\([a-zA-Z]\)", final_answer): + final_answer = final_answer[1] + return _normalise_result(final_answer) + + +def _normalise_result(string: str) -> str: + # linebreaks + string = string.replace("\n", "") + + # remove inverse spaces + string = string.replace("\\!", "") + + # replace \\ with \ + string = string.replace("\\\\", "\\") + + # replace tfrac and dfrac with frac + string = string.replace("cfrac", "frac") + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\le", "") + string = string.replace("\\right", "") + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = _remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace(r"\%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + string = string.split("=")[-1] + + # fix sqrt3 --> sqrt{3} + string = _fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string + + +def _remove_right_units(string: str) -> str: + # "\\text{ " only ever occurs (at least in the val set) when describing units + try: + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + except AssertionError: + return string + + +def _fix_sqrt(string: str) -> str: + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if len(split) == 0: + return string + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def _fix_fracs(string: str) -> str: + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if len(substr) == 0: + return string + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except AssertionError: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string: str) -> str: + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + ia = int(a) + ib = int(b) + assert string == "{}/{}".format(ia, ib) + new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}" + return new_string + except (ValueError, AssertionError): + return string diff --git a/llama_stack/providers/utils/scoring/basic_scoring_utils.py b/llama_stack/providers/utils/scoring/basic_scoring_utils.py new file mode 100644 index 000000000..91abfdb2e --- /dev/null +++ b/llama_stack/providers/utils/scoring/basic_scoring_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import contextlib +import signal +from types import FrameType +from typing import Iterator, Optional + + +class TimeoutError(Exception): + pass + + +@contextlib.contextmanager +def time_limit(seconds: float) -> Iterator[None]: + def signal_handler(signum: int, frame: Optional[FrameType]) -> None: + raise TimeoutError("Timed out!") + + signal.setitimer(signal.ITIMER_REAL, seconds) + signal.signal(signal.SIGALRM, signal_handler) + try: + yield + finally: + signal.setitimer(signal.ITIMER_REAL, 0) diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 47a2f2eb5..736b47746 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -33,7 +33,7 @@ providers: provider_type: remote::together config: url: https://api.together.xyz/v1 - api_key: ${env.TOGETHER_API_KEY} + api_key: ${env.TOGETHER_API_KEY:} vector_io: - provider_id: sqlite-vec provider_type: inline::sqlite-vec @@ -190,6 +190,21 @@ datasets: type: string chat_completion_input: type: string + - dataset_id: math_500 + provider_id: huggingface + url: + uri: https://huggingface.co/datasets/llamastack/math_500 + metadata: + path: llamastack/math_500 + name: + split: test + dataset_schema: + input_query: + type: string + expected_answer: + type: string + chat_completion_input: + type: string scoring_fns: [] benchmarks: - benchmark_id: meta-reference-simpleqa @@ -201,6 +216,9 @@ benchmarks: - benchmark_id: meta-reference-gpqa-cot dataset_id: gpqa_cot scoring_functions: ["basic::regex_parser_multiple_choice_answer"] + - benchmark_id: meta-reference-math-500 + dataset_id: math_500 + scoring_functions: ["basic::regex_parser_math_response"] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search