Mirror of https://github.com/meta-llama/llama-stack.git

Commit 10a9f6a5c8 ("temp commit"), parent 599873e485.
5 changed files with 64 additions and 11 deletions.
@@ -25,8 +25,9 @@ from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
 from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
 from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
+from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn

-FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
+FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]


 class BasicScoringImpl(
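
This hunk wires the new math scoring function into the basic provider: import the class, then append it to FIXED_FNS so BasicScoringImpl picks it up. As a rough illustration of the registry pattern behind FIXED_FNS, here is a minimal, self-contained sketch; every name in it is a hypothetical stand-in, not the actual llama-stack code:

    # Sketch of a fixed-function registry (hypothetical names throughout).
    from typing import Any, Dict, List


    class ToyScoringFn:
        """Base class; each subclass claims one or more identifiers."""

        identifiers: List[str] = []

        def score_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
            raise NotImplementedError


    class ToyEqualityFn(ToyScoringFn):
        identifiers = ["basic::equality"]

        def score_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
            matched = row["generated_answer"] == row["expected_answer"]
            return {"score": 1.0 if matched else 0.0}


    FIXED_FNS = [ToyEqualityFn]  # registration is just appending a class here

    # At provider init, each class is instantiated once and indexed by the
    # identifiers it claims, so score requests can be routed by function id.
    registry: Dict[str, ToyScoringFn] = {}
    for cls in FIXED_FNS:
        fn = cls()
        for identifier in fn.identifiers:
            registry[identifier] = fn

    row = {"generated_answer": "42", "expected_answer": "42"}
    print(registry["basic::equality"].score_row(row))  # {'score': 1.0}
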
@@ -19,7 +19,7 @@ regex_parser_math_response = ScoringFn(
     description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
     return_type=NumberType(),
     provider_id="basic",
-    provider_resource_id="regex-parser-matH-response",
+    provider_resource_id="regex-parser-math-response",
     params=RegexParserScoringFnParams(
         parsing_regexes=MATH_ANSWER_REGEXES,
         aggregation_functions=[AggregationFunctionType.accuracy],
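
This hunk fixes a stray capital letter in provider_resource_id; resource ids are compared as exact strings, so "regex-parser-matH-response" and "regex-parser-math-response" are different keys and the typo would break lookups against the registered id. The MATH_ANSWER_REGEXES referenced in params are not shown in this diff, but the commented-out pattern later in the diff suggests they capture the answer in a named group X. A small illustration of that mechanism; the pattern below is an assumed example, not the real contents of MATH_ANSWER_REGEXES:

    import re

    # Assumed example pattern; the actual MATH_ANSWER_REGEXES may differ.
    pattern = r"(?i)final answer is:?\s*(?P<X>.+?)\s*$"

    response = "Combining the halves, the final answer is: 3/4"
    match = re.search(pattern, response)
    if match:
        print(match.group("X"))  # 3/4
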
@@ -10,21 +10,21 @@ from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

-from .fn_defs.regex_parser_multiple_choice_answer import (
-    regex_parser_multiple_choice_answer,
+from .fn_defs.regex_parser_math_response import (
+    regex_parser_math_response,
 )
-from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex
+from ..utils.math_utils import normalize_final_answer, first_answer, try_evaluate_frac, try_evaluate_latex, extract_result_from_boxed


-class RegexParserScoringFn(RegisteredBaseScoringFn):
+class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
     """
-    A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
+    A scoring_fn for math benchmarks that parses the answer from the generated response and checks whether it matches the expected_answer.
     """

     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.supported_fn_defs_registry = {
-            regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
+            regex_parser_math_response.identifier: regex_parser_math_response,
         }

     async def score_row(
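
The newly imported extract_result_from_boxed helper is only referenced from commented-out code below, so this diff never shows its body. Judging by its name and by the \boxed{...} expected answers in this dataset, it plausibly pulls the content out of a LaTeX \boxed{} expression. A brace-aware sketch of that idea; this is an assumption about the helper's behavior, not the actual math_utils code:

    def extract_result_from_boxed_sketch(text: str) -> str:
        """Return the content of the last \\boxed{...} in text, tracking
        nested braces; return "" if none is found. (Sketch only; the real
        extract_result_from_boxed may behave differently.)"""
        marker = r"\boxed{"
        start = text.rfind(marker)
        if start == -1:
            return ""
        i = start + len(marker)
        depth = 1
        chars = []
        while i < len(text) and depth > 0:
            c = text[i]
            if c == "{":
                depth += 1
            elif c == "}":
                depth -= 1
            if depth > 0:
                chars.append(c)
            i += 1
        return "".join(chars)


    answer = r"Therefore, the polar coordinates are $\boxed{\left( 3, \frac{\pi}{2} \right)}.$"
    print(extract_result_from_boxed_sketch(answer))
    # \left( 3, \frac{\pi}{2} \right)
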
@@ -33,6 +33,7 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
         scoring_fn_identifier: Optional[str] = None,
         scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
+        print('I reach RegexParserMathResponseScoringFn')
         assert scoring_fn_identifier is not None, "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
         if scoring_params is not None:
@@ -43,26 +44,58 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
         )

         expected_answer = input_row["expected_answer"]
+        expected_answer = r"""
+We have that $r = \sqrt{0^2 + 3^2} = 3.$ Also, if we draw the line connecting the origin and $(0,3),$ this line makes an angle of $\frac{\pi}{2}$ with the positive $x$-axis.
+
+[asy]
+unitsize(0.8 cm);
+
+draw((-0.5,0)--(3.5,0));
+draw((0,-0.5)--(0,3.5));
+draw(arc((0,0),3,0,90),red,Arrow(6));
+
+dot((0,3), red);
+label("$(0,3)$", (0,3), W);
+dot((3,0), red);
+[/asy]
+
+Therefore, the polar coordinates are $\boxed{\left( 3, \frac{\pi}{2} \right)}.$
+"""
         generated_answer = input_row["generated_answer"]

+        print('expected_answer', expected_answer)
+        print('generated_answer', generated_answer)
+
         parsing_regexes = fn_def.params.parsing_regexes

         assert len(parsing_regexes) == 1, "Only one parsing regex is supported for regex_parser_math_response scoring function."

         parsing_regexes = fn_def.params.parsing_regexes[0]
-        # parsing_regexes = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
+        print('parsing_regexes', parsing_regexes)
+

         normalized_generated_answer = normalize_final_answer(
             first_answer(generated_answer),
             parsing_regexes,
             match_first=True,
         )
+        print('normalized_generated_answer_1', normalized_generated_answer)
+

         normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))

-        normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
+        print('normalized_generated_answer_2', normalized_generated_answer)
+
+
+        # print('extract_result_from_boxed', extract_result_from_boxed(expected_answer))
+        # normalized_expected_answer = normalize_final_answer(extract_result_from_boxed(expected_answer), r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$")
+        normalized_expected_answer = normalize_final_answer(expected_answer, r"\$\\boxed{(?P<X>.*)}\$")
+        print('normalized_expected_answer_1', normalized_expected_answer)
         normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))

+        print('normalized_expected_answer_2', normalized_expected_answer)
+
         score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
         return {
             "score": score,
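
The method above normalizes both sides before comparing: pull a candidate answer out with the parsing regex, run try_evaluate_latex and try_evaluate_frac so that equivalent forms (for example "3/4" and "0.75") compare equal, then score 1.0 on exact match of the normalized strings. The hardcoded expected_answer and the print calls are clearly temporary debugging for this "temp commit". A rough, self-contained approximation of the normalize-then-compare idea; the two helpers below are simplified stand-ins for the math_utils functions, not their real implementations:

    import re
    from fractions import Fraction


    def normalize_final_answer_sketch(text: str, pattern: str) -> str:
        """Extract named group X if the pattern captures one, else keep the
        text; then strip whitespace and dollar signs. (Simplified stand-in.)"""
        match = re.search(pattern, text, re.DOTALL)
        if match and "X" in match.groupdict():
            text = match.group("X")
        return text.strip().strip("$").strip()


    def try_evaluate_frac_sketch(text: str) -> str:
        """Render 'a/b' as a decimal string so fractions and decimals
        compare equal. (Simplified stand-in.)"""
        m = re.fullmatch(r"\s*(-?\d+)\s*/\s*(-?\d+)\s*", text)
        return str(float(Fraction(int(m.group(1)), int(m.group(2))))) if m else text


    generated = "Adding the halves, the final answer is: $3/4$"
    expected = "0.75"

    gen_norm = try_evaluate_frac_sketch(
        normalize_final_answer_sketch(generated, r"final answer is:?\s*(?P<X>.*)")
    )
    exp_norm = try_evaluate_frac_sketch(normalize_final_answer_sketch(expected, r".*"))
    print(gen_norm, exp_norm, 1.0 if gen_norm == exp_norm else 0.0)  # 0.75 0.75 1.0
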
@@ -32,6 +32,7 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
         scoring_fn_identifier: Optional[str] = None,
         scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
+        print("I reach RegexParserScoringFn")
         assert scoring_fn_identifier is not None, "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
         if scoring_params is not None:
@@ -33,7 +33,7 @@ providers:
     provider_type: remote::together
     config:
       url: https://api.together.xyz/v1
-      api_key: ${env.TOGETHER_API_KEY}
+      api_key: ${env.TOGETHER_API_KEY:}
   vector_io:
   - provider_id: sqlite-vec
     provider_type: inline::sqlite-vec
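
The only change here is the trailing colon in the env reference. In llama-stack run configs, the ${env.VAR:default} form appears to supply a fallback when the variable is unset, so ${env.TOGETHER_API_KEY:} would default to an empty string instead of failing at startup when the key is absent. A small sketch of that substitution rule; this approximates the behavior and is not the stack's actual resolver:

    import os
    import re

    # Approximation of ${env.VAR} / ${env.VAR:default} substitution.
    _ENV_REF = re.compile(r"\$\{env\.(?P<name>\w+)(?::(?P<default>[^}]*))?\}")


    def resolve_env_refs(value: str) -> str:
        def repl(m: re.Match) -> str:
            name, default = m.group("name"), m.group("default")
            if name in os.environ:
                return os.environ[name]
            if default is not None:  # colon present: fall back, even to ""
                return default
            raise KeyError(f"environment variable {name} is not set")

        return _ENV_REF.sub(repl, value)


    os.environ.pop("TOGETHER_API_KEY", None)
    print(repr(resolve_env_refs("${env.TOGETHER_API_KEY:}")))  # ''
    # resolve_env_refs("${env.TOGETHER_API_KEY}") would raise KeyError.
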
@@ -190,6 +190,21 @@ datasets:
       type: string
     chat_completion_input:
       type: string
+- dataset_id: math_500
+  provider_id: huggingface
+  url:
+    uri: https://huggingface.co/datasets/llamastack/math_500
+  metadata:
+    path: llamastack/math_500
+    name:
+    split: test
+  dataset_schema:
+    input_query:
+      type: string
+    expected_answer:
+      type: string
+    chat_completion_input:
+      type: string
 scoring_fns: []
 benchmarks:
 - benchmark_id: meta-reference-simpleqa
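
Per the dataset_schema above, each math_500 row carries an input_query, an expected_answer (a full worked solution ending in a \boxed{} result, as in the hardcoded sample earlier in this diff), and a chat_completion_input. A hypothetical row, with illustrative values only, just to show the shape:

    # Hypothetical math_500 row (illustrative values, not actual dataset content).
    row = {
        "input_query": "Convert the point (0,3) from rectangular to polar coordinates.",
        "expected_answer": (
            r"We have that $r = \sqrt{0^2 + 3^2} = 3.$ ... Therefore, the polar "
            r"coordinates are $\boxed{\left( 3, \frac{\pi}{2} \right)}.$"
        ),
        "chat_completion_input": '[{"role": "user", "content": "Convert the point (0,3) ..."}]',
    }
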
@@ -201,6 +216,9 @@ benchmarks:
 - benchmark_id: meta-reference-gpqa-cot
   dataset_id: gpqa_cot
   scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
+- benchmark_id: meta-reference-math-500
+  dataset_id: math_500
+  scoring_functions: ["basic::regex_parser_math_response"]
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
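
With the dataset and benchmark registered, the new scoring function becomes reachable end to end. A sketch of kicking the benchmark off through the Python client; the method shape below is recalled from llama-stack-client documentation and the model id is an assumption, so treat the details as version-dependent:

    # Sketch: running the new benchmark via llama-stack-client.
    # Method/parameter shapes are assumptions and may vary by version.
    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")

    job = client.eval.run_eval(
        benchmark_id="meta-reference-math-500",
        benchmark_config={
            "eval_candidate": {
                "type": "model",
                "model": "meta-llama/Llama-3.3-70B-Instruct",  # assumed model id
                "sampling_params": {"max_tokens": 4096},
            },
        },
    )
    print(job)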