mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
init commit
This commit is contained in:
parent
3a454be9b2
commit
b99da3b9e4
3 changed files with 456 additions and 0 deletions
|
@ -0,0 +1,69 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from llama_stack.apis.scoring import ScoringResultRow
|
||||||
|
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||||
|
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||||
|
|
||||||
|
from .fn_defs.regex_parser_multiple_choice_answer import (
|
||||||
|
regex_parser_multiple_choice_answer,
|
||||||
|
)
|
||||||
|
from ...utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex
|
||||||
|
|
||||||
|
|
||||||
|
class RegexParserScoringFn(RegisteredBaseScoringFn):
|
||||||
|
"""
|
||||||
|
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.supported_fn_defs_registry = {
|
||||||
|
regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def score_row(
|
||||||
|
self,
|
||||||
|
input_row: Dict[str, Any],
|
||||||
|
scoring_fn_identifier: Optional[str] = None,
|
||||||
|
scoring_params: Optional[ScoringFnParams] = None,
|
||||||
|
) -> ScoringResultRow:
|
||||||
|
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||||
|
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||||
|
if scoring_params is not None:
|
||||||
|
fn_def.params = scoring_params
|
||||||
|
|
||||||
|
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
|
||||||
|
f"RegexParserScoringFnParams not found for {fn_def}."
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_answer = input_row["expected_answer"]
|
||||||
|
generated_answer = input_row["generated_answer"]
|
||||||
|
|
||||||
|
|
||||||
|
pattern = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
|
||||||
|
normalized_generated_answer = normalize_final_answer(
|
||||||
|
first_answer(generated_answer),
|
||||||
|
pattern,
|
||||||
|
match_first=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
|
||||||
|
|
||||||
|
# parse answer according to regex
|
||||||
|
parsed_answer = None
|
||||||
|
for regex in fn_def.params.parsing_regexes:
|
||||||
|
match = re.search(regex, generated_answer)
|
||||||
|
if match:
|
||||||
|
parsed_answer = match.group(1)
|
||||||
|
break
|
||||||
|
|
||||||
|
score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
|
||||||
|
return {
|
||||||
|
"score": score,
|
||||||
|
}
|
|
@ -0,0 +1,44 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from llama_stack.apis.scoring import ScoringResultRow
|
||||||
|
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||||
|
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||||
|
|
||||||
|
from .fn_defs.equality import equality
|
||||||
|
|
||||||
|
|
||||||
|
# class EqualityScoringFn(RegisteredBaseScoringFn):
|
||||||
|
class MathExactMatchFn(RegisteredBaseScoringFn):
|
||||||
|
"""
|
||||||
|
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.supported_fn_defs_registry = {
|
||||||
|
equality.identifier: equality,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def score_row(
|
||||||
|
self,
|
||||||
|
input_row: Dict[str, Any],
|
||||||
|
scoring_fn_identifier: Optional[str] = "equality",
|
||||||
|
scoring_params: Optional[ScoringFnParams] = None,
|
||||||
|
) -> ScoringResultRow:
|
||||||
|
assert "expected_answer" in input_row, "Expected answer not found in input row."
|
||||||
|
assert "generated_answer" in input_row, "Generated answer not found in input row."
|
||||||
|
|
||||||
|
expected_answer = input_row["expected_answer"]
|
||||||
|
generated_answer = input_row["generated_answer"]
|
||||||
|
|
||||||
|
|
||||||
|
score = 1.0 if expected_answer == generated_answer else 0.0
|
||||||
|
return {
|
||||||
|
"score": score,
|
||||||
|
}
|
343
llama_stack/providers/inline/scoring/basic/utils/math_utils.py
Normal file
343
llama_stack/providers/inline/scoring/basic/utils/math_utils.py
Normal file
|
@ -0,0 +1,343 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import re
|
||||||
|
from types import FrameType
|
||||||
|
from typing import Sequence, Iterator, Optional
|
||||||
|
import sympy
|
||||||
|
import signal
|
||||||
|
|
||||||
|
class TimeoutException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def time_limit(seconds: float) -> Iterator[None]:
|
||||||
|
def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
|
||||||
|
raise TimeoutException("Timed out!")
|
||||||
|
|
||||||
|
signal.setitimer(signal.ITIMER_REAL, seconds)
|
||||||
|
signal.signal(signal.SIGALRM, signal_handler)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
signal.setitimer(signal.ITIMER_REAL, 0)
|
||||||
|
|
||||||
|
# from minerva
|
||||||
|
SUBSTITUTIONS = [
|
||||||
|
("an ", ""),
|
||||||
|
("a ", ""),
|
||||||
|
(".$", "$"),
|
||||||
|
("\\$", ""),
|
||||||
|
(r"\ ", ""),
|
||||||
|
(" ", ""),
|
||||||
|
("mbox", "text"),
|
||||||
|
(",\\text{and}", ","),
|
||||||
|
("\\text{and}", ","),
|
||||||
|
("\\text{m}", "\\text{}"),
|
||||||
|
]
|
||||||
|
|
||||||
|
REMOVED_EXPRESSIONS = [
|
||||||
|
"square",
|
||||||
|
"ways",
|
||||||
|
"integers",
|
||||||
|
"dollars",
|
||||||
|
"mph",
|
||||||
|
"inches",
|
||||||
|
"ft",
|
||||||
|
"hours",
|
||||||
|
"km",
|
||||||
|
"units",
|
||||||
|
"\\ldots",
|
||||||
|
"sue",
|
||||||
|
"points",
|
||||||
|
"feet",
|
||||||
|
"minutes",
|
||||||
|
"digits",
|
||||||
|
"cents",
|
||||||
|
"degrees",
|
||||||
|
"cm",
|
||||||
|
"gm",
|
||||||
|
"pounds",
|
||||||
|
"meters",
|
||||||
|
"meals",
|
||||||
|
"edges",
|
||||||
|
"students",
|
||||||
|
"childrentickets",
|
||||||
|
"multiples",
|
||||||
|
"\\text{s}",
|
||||||
|
"\\text{.}",
|
||||||
|
"\\text{\ns}",
|
||||||
|
"\\text{}^2",
|
||||||
|
"\\text{}^3",
|
||||||
|
"\\text{\n}",
|
||||||
|
"\\text{}",
|
||||||
|
r"\mathrm{th}",
|
||||||
|
r"^\circ",
|
||||||
|
r"^{\circ}",
|
||||||
|
r"\;",
|
||||||
|
r",\!",
|
||||||
|
"{,}",
|
||||||
|
'"',
|
||||||
|
"\\dots",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
|
||||||
|
if isinstance(expression, float):
|
||||||
|
return expression
|
||||||
|
new_expression = f"{expression}"
|
||||||
|
regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}")
|
||||||
|
for match in re.finditer(regex, expression):
|
||||||
|
try:
|
||||||
|
value = float(match.group(1)) / float(match.group(2))
|
||||||
|
new_expression = new_expression.replace(
|
||||||
|
match.group(),
|
||||||
|
f"{{value:{fmt}}}".format(value=value),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return new_expression
|
||||||
|
|
||||||
|
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
|
||||||
|
try:
|
||||||
|
with time_limit(seconds=5):
|
||||||
|
from sympy.parsing.latex import parse_latex
|
||||||
|
|
||||||
|
value = parse_latex(expression).evalf() # type: ignore
|
||||||
|
return f"{{value:{fmt}}}".format(value=value)
|
||||||
|
except Exception:
|
||||||
|
return expression
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
|
||||||
|
for marker in markers:
|
||||||
|
text = text.split(marker)[0]
|
||||||
|
return text
|
||||||
|
|
||||||
|
def extract_result_from_boxed(answer: str) -> str:
|
||||||
|
box_start = "\\boxed"
|
||||||
|
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
|
||||||
|
start = answer.rfind(box_start)
|
||||||
|
if start < 0:
|
||||||
|
return ""
|
||||||
|
answer = answer[start + len(box_start) :].strip()
|
||||||
|
ends_with_curly = answer.startswith("{")
|
||||||
|
i = 0
|
||||||
|
open_braces = 0
|
||||||
|
while i < len(answer):
|
||||||
|
if answer[i] == "{":
|
||||||
|
open_braces += 1
|
||||||
|
elif answer[i] == "}":
|
||||||
|
open_braces -= 1
|
||||||
|
if open_braces == 0:
|
||||||
|
if ends_with_curly:
|
||||||
|
answer = answer[: i + 1].strip()
|
||||||
|
break
|
||||||
|
elif answer[i] == "$":
|
||||||
|
answer = answer[:i].strip()
|
||||||
|
break
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
return ""
|
||||||
|
# remove extra curly braces
|
||||||
|
while True:
|
||||||
|
if answer.startswith("{") and answer.endswith("}"):
|
||||||
|
answer = answer[1:-1].strip()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return answer
|
||||||
|
|
||||||
|
|
||||||
|
# from minerva paper + _normalise_result from xavierm
|
||||||
|
def normalize_final_answer(
|
||||||
|
final_answer: str, regex_pattern: str, match_first: bool = True
|
||||||
|
) -> str:
|
||||||
|
"""Extract and normalize a final answer to a quantitative reasoning question."""
|
||||||
|
match = re.findall(regex_pattern, final_answer)
|
||||||
|
extraction: str
|
||||||
|
if len(match) > 0:
|
||||||
|
if match_first:
|
||||||
|
extraction = match[0]
|
||||||
|
else:
|
||||||
|
extraction = match[-1]
|
||||||
|
else:
|
||||||
|
extraction = extract_result_from_boxed(final_answer)
|
||||||
|
|
||||||
|
if len(extraction) == 0:
|
||||||
|
return final_answer
|
||||||
|
else:
|
||||||
|
final_answer = extraction
|
||||||
|
final_answer = final_answer.split("=")[-1]
|
||||||
|
for before, after in SUBSTITUTIONS:
|
||||||
|
final_answer = final_answer.replace(before, after)
|
||||||
|
for expr in REMOVED_EXPRESSIONS:
|
||||||
|
final_answer = final_answer.replace(expr, "")
|
||||||
|
# Extract answer that is in LaTeX math, is bold,
|
||||||
|
# is surrounded by a box, etc.
|
||||||
|
final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
|
||||||
|
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
|
||||||
|
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
|
||||||
|
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
|
||||||
|
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
|
||||||
|
# Normalize shorthand TeX:
|
||||||
|
# \fracab -> \frac{a}{b}
|
||||||
|
# \frac{abc}{bef} -> \frac{abc}{bef}
|
||||||
|
# \fracabc -> \frac{a}{b}c
|
||||||
|
# \sqrta -> \sqrt{a}
|
||||||
|
# \sqrtab -> sqrt{a}b
|
||||||
|
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
|
||||||
|
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
|
||||||
|
final_answer = final_answer.replace("$", "")
|
||||||
|
# Normalize 100,000 -> 100000
|
||||||
|
if final_answer.replace(",", "").isdigit():
|
||||||
|
final_answer = final_answer.replace(",", "")
|
||||||
|
# If the final answer is a single letter in parentheses, remove the parentheses
|
||||||
|
# Example: (a) -> a (but not (ab) -> ab)
|
||||||
|
if re.match(r"\([a-zA-Z]\)", final_answer):
|
||||||
|
final_answer = final_answer[1]
|
||||||
|
return _normalise_result(final_answer)
|
||||||
|
|
||||||
|
def _normalise_result(string: str) -> str:
|
||||||
|
# linebreaks
|
||||||
|
string = string.replace("\n", "")
|
||||||
|
|
||||||
|
# remove inverse spaces
|
||||||
|
string = string.replace("\\!", "")
|
||||||
|
|
||||||
|
# replace \\ with \
|
||||||
|
string = string.replace("\\\\", "\\")
|
||||||
|
|
||||||
|
# replace tfrac and dfrac with frac
|
||||||
|
string = string.replace("cfrac", "frac")
|
||||||
|
string = string.replace("tfrac", "frac")
|
||||||
|
string = string.replace("dfrac", "frac")
|
||||||
|
|
||||||
|
# remove \left and \right
|
||||||
|
string = string.replace("\\left", "")
|
||||||
|
string = string.replace("\\right", "")
|
||||||
|
|
||||||
|
# Remove circ (degrees)
|
||||||
|
string = string.replace("^{\\circ}", "")
|
||||||
|
string = string.replace("^\\circ", "")
|
||||||
|
|
||||||
|
# remove dollar signs
|
||||||
|
string = string.replace("\\$", "")
|
||||||
|
|
||||||
|
# remove units (on the right)
|
||||||
|
string = _remove_right_units(string)
|
||||||
|
|
||||||
|
# remove percentage
|
||||||
|
string = string.replace("\\%", "")
|
||||||
|
string = string.replace(r"\%", "")
|
||||||
|
|
||||||
|
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||||
|
string = string.replace(" .", " 0.")
|
||||||
|
string = string.replace("{.", "{0.")
|
||||||
|
# if empty, return empty string
|
||||||
|
if len(string) == 0:
|
||||||
|
return string
|
||||||
|
if string[0] == ".":
|
||||||
|
string = "0" + string
|
||||||
|
|
||||||
|
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||||
|
string = string.split("=")[-1]
|
||||||
|
|
||||||
|
# fix sqrt3 --> sqrt{3}
|
||||||
|
string = _fix_sqrt(string)
|
||||||
|
|
||||||
|
# remove spaces
|
||||||
|
string = string.replace(" ", "")
|
||||||
|
|
||||||
|
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||||
|
string = _fix_fracs(string)
|
||||||
|
|
||||||
|
# manually change 0.5 --> \frac{1}{2}
|
||||||
|
if string == "0.5":
|
||||||
|
string = "\\frac{1}{2}"
|
||||||
|
|
||||||
|
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||||
|
string = _fix_a_slash_b(string)
|
||||||
|
|
||||||
|
return string
|
||||||
|
|
||||||
|
def _remove_right_units(string: str) -> str:
|
||||||
|
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
||||||
|
try:
|
||||||
|
if "\\text{ " in string:
|
||||||
|
splits = string.split("\\text{ ")
|
||||||
|
assert len(splits) == 2
|
||||||
|
return splits[0]
|
||||||
|
else:
|
||||||
|
return string
|
||||||
|
except AssertionError:
|
||||||
|
return string
|
||||||
|
|
||||||
|
def _fix_sqrt(string: str) -> str:
|
||||||
|
if "\\sqrt" not in string:
|
||||||
|
return string
|
||||||
|
splits = string.split("\\sqrt")
|
||||||
|
new_string = splits[0]
|
||||||
|
for split in splits[1:]:
|
||||||
|
if len(split) == 0:
|
||||||
|
return string
|
||||||
|
if split[0] != "{":
|
||||||
|
a = split[0]
|
||||||
|
new_substr = "\\sqrt{" + a + "}" + split[1:]
|
||||||
|
else:
|
||||||
|
new_substr = "\\sqrt" + split
|
||||||
|
new_string += new_substr
|
||||||
|
return new_string
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_fracs(string: str) -> str:
|
||||||
|
substrs = string.split("\\frac")
|
||||||
|
new_str = substrs[0]
|
||||||
|
if len(substrs) > 1:
|
||||||
|
substrs = substrs[1:]
|
||||||
|
for substr in substrs:
|
||||||
|
new_str += "\\frac"
|
||||||
|
if len(substr) == 0:
|
||||||
|
return string
|
||||||
|
if substr[0] == "{":
|
||||||
|
new_str += substr
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
assert len(substr) >= 2
|
||||||
|
except AssertionError:
|
||||||
|
return string
|
||||||
|
a = substr[0]
|
||||||
|
b = substr[1]
|
||||||
|
if b != "{":
|
||||||
|
if len(substr) > 2:
|
||||||
|
post_substr = substr[2:]
|
||||||
|
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||||
|
else:
|
||||||
|
new_str += "{" + a + "}{" + b + "}"
|
||||||
|
else:
|
||||||
|
if len(substr) > 2:
|
||||||
|
post_substr = substr[2:]
|
||||||
|
new_str += "{" + a + "}" + b + post_substr
|
||||||
|
else:
|
||||||
|
new_str += "{" + a + "}" + b
|
||||||
|
string = new_str
|
||||||
|
return string
|
||||||
|
|
||||||
|
def _fix_a_slash_b(string: str) -> str:
|
||||||
|
if len(string.split("/")) != 2:
|
||||||
|
return string
|
||||||
|
a = string.split("/")[0]
|
||||||
|
b = string.split("/")[1]
|
||||||
|
try:
|
||||||
|
ia = int(a)
|
||||||
|
ib = int(b)
|
||||||
|
assert string == "{}/{}".format(ia, ib)
|
||||||
|
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
|
||||||
|
return new_string
|
||||||
|
except (ValueError, AssertionError):
|
||||||
|
return string
|
Loading…
Add table
Add a link
Reference in a new issue