diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py
index 13cd78243..cbaf015d6 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring.py
@@ -25,8 +25,9 @@ from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
 from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
 from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
+from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
 
-FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
+FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
 
 
 class BasicScoringImpl(
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py
index 7d5f49c8a..8b1bf5352 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py
@@ -19,7 +19,7 @@ regex_parser_math_response = ScoringFn(
     description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
     return_type=NumberType(),
     provider_id="basic",
-    provider_resource_id="regex-parser-matH-response",
+    provider_resource_id="regex-parser-math-response",
     params=RegexParserScoringFnParams(
         parsing_regexes=MATH_ANSWER_REGEXES,
         aggregation_functions=[AggregationFunctionType.accuracy],
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py
index 35789fb2c..6ad3813ca 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py
@@ -10,21 +10,21 @@ from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
 
-from .fn_defs.regex_parser_multiple_choice_answer import (
-    regex_parser_multiple_choice_answer,
+from .fn_defs.regex_parser_math_response import (
+    regex_parser_math_response,
 )
 from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex
 
 
-class RegexParserScoringFn(RegisteredBaseScoringFn):
+class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
     """
-    A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
+    A scoring_fn for math benchmarks that parses the answer from the generated response according to context and checks whether it matches the expected_answer.
""" def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.supported_fn_defs_registry = { - regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer, + regex_parser_math_response.identifier: regex_parser_math_response, } async def score_row( @@ -33,6 +33,7 @@ class RegexParserScoringFn(RegisteredBaseScoringFn): scoring_fn_identifier: Optional[str] = None, scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: + print('I reach RegexParserMathResponseScoringFn') assert scoring_fn_identifier is not None, "Scoring function identifier not found." fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] if scoring_params is not None: @@ -43,26 +44,58 @@ class RegexParserScoringFn(RegisteredBaseScoringFn): ) expected_answer = input_row["expected_answer"] + expected_answer = r""" + We have that $r = \sqrt{0^2 + 3^2} = 3.$ Also, if we draw the line connecting the origin and $(0,3),$ this line makes an angle of $\frac{\pi}{2}$ with the positive $x$-axis. + + [asy] + unitsize(0.8 cm); + + draw((-0.5,0)--(3.5,0)); + draw((0,-0.5)--(0,3.5)); + draw(arc((0,0),3,0,90),red,Arrow(6)); + + dot((0,3), red); + label("$(0,3)$", (0,3), W); + dot((3,0), red); + [/asy] + + Therefore, the polar coordinates are $\boxed{\left( 3, \frac{\pi}{2} \right)}.$ + """ generated_answer = input_row["generated_answer"] + print('expected_answer', expected_answer) + print('generated_answer', generated_answer) + parsing_regexes = fn_def.params.parsing_regexes assert len(parsing_regexes) == 1, "Only one parsing regex is supported for regex_parser_math_response scoring function." parsing_regexes = fn_def.params.parsing_regexes[0] - # parsing_regexes = r".*final answer is:?\s*\$\\boxed{(?P.*)}\$" + + print('parsing_regexes', parsing_regexes) + normalized_generated_answer = normalize_final_answer( first_answer(generated_answer), parsing_regexes, match_first=True, ) + print('normalized_generated_answer_1', normalized_generated_answer) + normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer)) - normalized_expected_answer = normalize_final_answer(expected_answer, r".*") + print('normalized_generated_answer_2', normalized_generated_answer) + + + # print('extract_result_from_boxed', extract_result_from_boxed(expected_answer)) + # normalized_expected_answer = normalize_final_answer(extract_result_from_boxed(expected_answer), r".*final answer is:?\s*\$\\boxed{(?P.*)}\$") + normalized_expected_answer = normalize_final_answer(expected_answer, r"\$\\boxed{(?P.*)}\$") + print('normalized_expected_answer_1', normalized_expected_answer) normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer)) + print('normalized_expected_answer_2', normalized_expected_answer) + score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0 return { "score": score, diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py index 0606a9581..bc105ebfb 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py @@ -32,6 +32,7 @@ class RegexParserScoringFn(RegisteredBaseScoringFn): scoring_fn_identifier: Optional[str] = None, scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: + print("I reach RegexParserScoringFn") assert 
diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml
index 47a2f2eb5..736b47746 100644
--- a/llama_stack/templates/open-benchmark/run.yaml
+++ b/llama_stack/templates/open-benchmark/run.yaml
@@ -33,7 +33,7 @@ providers:
     provider_type: remote::together
     config:
       url: https://api.together.xyz/v1
-      api_key: ${env.TOGETHER_API_KEY}
+      api_key: ${env.TOGETHER_API_KEY:}
   vector_io:
   - provider_id: sqlite-vec
     provider_type: inline::sqlite-vec
@@ -190,6 +190,21 @@ datasets:
       type: string
     chat_completion_input:
       type: string
+- dataset_id: math_500
+  provider_id: huggingface
+  url:
+    uri: https://huggingface.co/datasets/llamastack/math_500
+  metadata:
+    path: llamastack/math_500
+    name:
+    split: test
+  dataset_schema:
+    input_query:
+      type: string
+    expected_answer:
+      type: string
+    chat_completion_input:
+      type: string
 scoring_fns: []
 benchmarks:
 - benchmark_id: meta-reference-simpleqa
@@ -201,6 +216,9 @@ benchmarks:
 - benchmark_id: meta-reference-gpqa-cot
   dataset_id: gpqa_cot
   scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
+- benchmark_id: meta-reference-math-500
+  dataset_id: math_500
+  scoring_functions: ["basic::regex_parser_math_response"]
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
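Note on the parsing change above: the expected_answer column in math_500 holds a full worked solution, so the new regex pulls the final \boxed{...} result out of it before comparison. The sketch below is a minimal, self-contained illustration of that extraction using only the standard re module; extract_boxed is a hypothetical helper, not the patch's code, and the real scoring function additionally normalizes and numerically evaluates both sides via llama_stack's normalize_final_answer, try_evaluate_latex, and try_evaluate_frac.

import re

# Hypothetical stand-in for the \boxed{...} extraction used in the patch.
# Greedy .* runs to the last closing brace, so nested braces inside the
# boxed expression (e.g. \frac{\pi}{2}) stay intact.
BOXED = re.compile(r"\\boxed\{(?P<X>.*)\}")

def extract_boxed(text: str) -> str:
    """Return the contents of the last \\boxed{...} span, else the raw text."""
    matches = BOXED.findall(text)
    return matches[-1].strip() if matches else text.strip()

expected = r"Therefore, the polar coordinates are $\boxed{\left( 3, \frac{\pi}{2} \right)}.$"
generated = r"The final answer is $\boxed{\left( 3, \frac{\pi}{2} \right)}$."

score = 1.0 if extract_boxed(generated) == extract_boxed(expected) else 0.0
print(score)  # 1.0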