mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-11 20:40:40 +00:00

temp commit

This commit is contained in:
parent b99da3b9e4
commit 6555e2136f

3 changed files with 38 additions and 55 deletions
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    RegexParserScoringFnParams,
+    ScoringFn,
+)
+
+MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
+
+
+regex_parser_math_response = ScoringFn(
+    identifier="basic::regex_parser_math_response",
+    description="For math-related benchmarks, extract the answer from the generated response and the expected_answer and check whether they match.",
+    return_type=NumberType(),
+    provider_id="basic",
+    provider_resource_id="regex-parser-math-response",
+    params=RegexParserScoringFnParams(
+        parsing_regexes=MATH_ANSWER_REGEXES,
+        aggregation_functions=[AggregationFunctionType.accuracy],
+    ),
+)
@@ -1,44 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any, Dict, Optional
-
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
-
-from .fn_defs.equality import equality
-
-
-# class EqualityScoringFn(RegisteredBaseScoringFn):
-class MathExactMatchFn(RegisteredBaseScoringFn):
-    """
-    A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
-    """
-
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.supported_fn_defs_registry = {
-            equality.identifier: equality,
-        }
-
-    async def score_row(
-        self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "equality",
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> ScoringResultRow:
-        assert "expected_answer" in input_row, "Expected answer not found in input row."
-        assert "generated_answer" in input_row, "Generated answer not found in input row."
-
-        expected_answer = input_row["expected_answer"]
-        generated_answer = input_row["generated_answer"]
-
-
-        score = 1.0 if expected_answer == generated_answer else 0.0
-        return {
-            "score": score,
-        }
@@ -13,7 +13,7 @@ from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
 from .fn_defs.regex_parser_multiple_choice_answer import (
     regex_parser_multiple_choice_answer,
 )
-from ...utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex
+from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex


 class RegexParserScoringFn(RegisteredBaseScoringFn):
@@ -45,25 +45,25 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]

-        pattern = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
+        parsing_regexes = fn_def.params.parsing_regexes
+        assert len(parsing_regexes) == 1, "Only one parsing regex is supported for regex_parser_math_response scoring function."
+        parsing_regexes = fn_def.params.parsing_regexes[0]
+        # parsing_regexes = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"

         normalized_generated_answer = normalize_final_answer(
             first_answer(generated_answer),
-            pattern,
+            parsing_regexes,
             match_first=True,
         )
         normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))

-        # parse answer according to regex
-        parsed_answer = None
-        for regex in fn_def.params.parsing_regexes:
-            match = re.search(regex, generated_answer)
-            if match:
-                parsed_answer = match.group(1)
-                break
-
-        score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
+        normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
+        normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
+
+        score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
         return {
             "score": score,
         }