temp commit

Botao Chen 2025-03-09 23:56:36 -07:00
parent 599873e485
commit 10a9f6a5c8
5 changed files with 64 additions and 11 deletions

View file

@@ -25,8 +25,9 @@ from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
 from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
 from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
+from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn

-FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
+FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]


 class BasicScoringImpl(
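
For context, a minimal sketch of the registry pattern that FIXED_FNS feeds: each class is instantiated once and every identifier in its supported_fn_defs_registry is mapped to that instance, which is how basic::regex_parser_math_response becomes resolvable. The names below are hypothetical stand-ins, not the actual BasicScoringImpl code:

# Hypothetical sketch of the FIXED_FNS wiring; _DemoScoringFn stands in for
# RegisteredBaseScoringFn and is not llama_stack code.
class _DemoScoringFn:
    supported_fn_defs_registry: dict = {}


class _DemoMathResponseFn(_DemoScoringFn):
    # Mirrors RegexParserMathResponseScoringFn registering its one fn_def.
    supported_fn_defs_registry = {"basic::regex_parser_math_response": "fn_def placeholder"}


FIXED_FNS = [_DemoMathResponseFn]  # the real list also holds Equality/SubsetOf/RegexParser

scoring_fn_id_map = {}
for fn_cls in FIXED_FNS:
    impl = fn_cls()  # instantiate each scoring fn once
    for identifier in impl.supported_fn_defs_registry:
        scoring_fn_id_map[identifier] = impl  # index every identifier it supports

assert "basic::regex_parser_math_response" in scoring_fn_id_map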

View file

@@ -19,7 +19,7 @@ regex_parser_math_response = ScoringFn(
     description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
     return_type=NumberType(),
     provider_id="basic",
-    provider_resource_id="regex-parser-matH-response",
+    provider_resource_id="regex-parser-math-response",
     params=RegexParserScoringFnParams(
         parsing_regexes=MATH_ANSWER_REGEXES,
         aggregation_functions=[AggregationFunctionType.accuracy],

View file

@@ -10,21 +10,21 @@ from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

-from .fn_defs.regex_parser_multiple_choice_answer import (
-    regex_parser_multiple_choice_answer,
+from .fn_defs.regex_parser_math_response import (
+    regex_parser_math_response,
 )
-from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex
+from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex, extract_result_from_boxed


-class RegexParserScoringFn(RegisteredBaseScoringFn):
+class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
     """
-    A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
+    A scoring_fn for math benchmarks that parses the answer from the generated response according to context and checks whether it matches the expected_answer.
     """

     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.supported_fn_defs_registry = {
-            regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
+            regex_parser_math_response.identifier: regex_parser_math_response,
         }

     async def score_row(
@@ -33,6 +33,7 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
         scoring_fn_identifier: Optional[str] = None,
         scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
+        print('I reach RegexParserMathResponseScoringFn')
         assert scoring_fn_identifier is not None, "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
         if scoring_params is not None:
@@ -43,26 +44,58 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
         )

         expected_answer = input_row["expected_answer"]
+        expected_answer = r"""
+We have that $r = \sqrt{0^2 + 3^2} = 3.$ Also, if we draw the line connecting the origin and $(0,3),$ this line makes an angle of $\frac{\pi}{2}$ with the positive $x$-axis.
+
+[asy]
+unitsize(0.8 cm);
+
+draw((-0.5,0)--(3.5,0));
+draw((0,-0.5)--(0,3.5));
+draw(arc((0,0),3,0,90),red,Arrow(6));
+
+dot((0,3), red);
+label("$(0,3)$", (0,3), W);
+dot((3,0), red);
+[/asy]
+
+Therefore, the polar coordinates are $\boxed{\left( 3, \frac{\pi}{2} \right)}.$
+"""
         generated_answer = input_row["generated_answer"]
+        print('expected_answer', expected_answer)
+        print('generated_answer', generated_answer)

         parsing_regexes = fn_def.params.parsing_regexes
         assert len(parsing_regexes) == 1, "Only one parsing regex is supported for regex_parser_math_response scoring function."
         parsing_regexes = fn_def.params.parsing_regexes[0]
+        # parsing_regexes = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
+        print('parsing_regexes', parsing_regexes)

         normalized_generated_answer = normalize_final_answer(
             first_answer(generated_answer),
             parsing_regexes,
             match_first=True,
         )
+        print('normalized_generated_answer_1', normalized_generated_answer)
         normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
-        normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
+        print('normalized_generated_answer_2', normalized_generated_answer)
+
+        # print('extract_result_from_boxed', extract_result_from_boxed(expected_answer))
+        # normalized_expected_answer = normalize_final_answer(extract_result_from_boxed(expected_answer), r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$")
+        normalized_expected_answer = normalize_final_answer(expected_answer, r"\$\\boxed{(?P<X>.*)}\$")
+        print('normalized_expected_answer_1', normalized_expected_answer)
         normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
+        print('normalized_expected_answer_2', normalized_expected_answer)

         score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
         return {
             "score": score,

View file

@@ -32,6 +32,7 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
         scoring_fn_identifier: Optional[str] = None,
         scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
+        print("I reach RegexParserScoringFn")
         assert scoring_fn_identifier is not None, "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
         if scoring_params is not None:

View file

@@ -33,7 +33,7 @@ providers:
     provider_type: remote::together
     config:
       url: https://api.together.xyz/v1
-      api_key: ${env.TOGETHER_API_KEY}
+      api_key: ${env.TOGETHER_API_KEY:}
   vector_io:
   - provider_id: sqlite-vec
     provider_type: inline::sqlite-vec
@@ -190,6 +190,21 @@ datasets:
       type: string
     chat_completion_input:
       type: string
+- dataset_id: math_500
+  provider_id: huggingface
+  url:
+    uri: https://huggingface.co/datasets/llamastack/math_500
+  metadata:
+    path: llamastack/math_500
+    name:
+    split: test
+  dataset_schema:
+    input_query:
+      type: string
+    expected_answer:
+      type: string
+    chat_completion_input:
+      type: string
 scoring_fns: []
 benchmarks:
 - benchmark_id: meta-reference-simpleqa
@@ -201,6 +216,9 @@ benchmarks:
 - benchmark_id: meta-reference-gpqa-cot
   dataset_id: gpqa_cot
   scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
+- benchmark_id: meta-reference-math-500
+  dataset_id: math_500
+  scoring_functions: ["basic::regex_parser_math_response"]
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
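
On the run.yaml change: the trailing colon in `${env.TOGETHER_API_KEY:}` supplies an empty-string default, so a missing TOGETHER_API_KEY no longer has to abort startup. A small illustrative resolver — an assumption about the template semantics, not llama-stack's actual env-substitution code — behaves like this:

import os
import re

# Illustrative ${env.VAR} / ${env.VAR:default} resolver (assumed semantics,
# not llama-stack's implementation): bare ${env.VAR} fails when VAR is unset,
# while a trailing ":" falls back to the (possibly empty) default.
_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")

def substitute_env(value: str) -> str:
    def repl(m: re.Match) -> str:
        name, default = m.group(1), m.group(2)
        if name in os.environ:
            return os.environ[name]
        if default is not None:  # "" for a bare trailing colon
            return default
        raise KeyError(f"environment variable {name} is not set and has no default")
    return _ENV_PATTERN.sub(repl, value)

# DEMO_TOGETHER_KEY is a hypothetical variable name used only for this demo.
os.environ.pop("DEMO_TOGETHER_KEY", None)
assert substitute_env("api_key: ${env.DEMO_TOGETHER_KEY:}") == "api_key: "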