diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
index 56cf13503..3fbd219fb 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
@@ -13,21 +13,14 @@ from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.inference.inference import Inference
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
-    EqualityScoringFn,
-)
-
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
-    LlmAsJudgeScoringFn,
-)
-
-from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
-    SubsetOfScoringFn,
-)
 
 from .config import MetaReferenceScoringConfig
+from .scoring_fn.equality_scoring_fn import EqualityScoringFn
+from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
+from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
+from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
 
-FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]
+FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
 
 LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
 
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py
index 99fa6cc3a..b54bf7ae8 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py
@@ -11,6 +11,5 @@ from llama_stack.apis.scoring_functions import ScoringFnDef
 equality = ScoringFnDef(
     identifier="meta-reference::equality",
     description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
-    parameters=[],
     return_type=NumberType(),
 )
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
index cfef52160..68d77b8df 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
@@ -26,7 +26,6 @@ Total rating:
 llm_as_judge_8b_correctness = ScoringFnDef(
     identifier="meta-reference::llm_as_judge_8b_correctness",
     description="Llm As Judge Scoring Function",
-    parameters=[],
     return_type=NumberType(),
     params=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py
new file mode 100644
index 000000000..84e518887
--- /dev/null
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import NumberType
+
+MULTILINGUAL_ANSWER_REGEXES = [
+    r"Answer\s*:",
+    r"Answer\s*:​​​​​​",  # Korean invisible character
+    r"উত্তর\s*:",
+    r"उत्तर\s*:",
+    r"উত্তরঃ",
+    r"উত্তর\s*:",
+    r"Antwort\s*:",
+    r"답변\s*:",
+    r"정답\s*:",
+    r"답\s*:",
+    r"答案\s*：",
+    r"答案\s*:",
+    r"答\s*：",
+    r"答\s*:",
+    r"答复\s*:",
+    r"答曰\s*:",
+    r"الإجابة:",
+    r"الجواب:",
+    r"إجابة:",
+    r"الإجابة النهائية:",
+    r"الإجابة الصحيحة:",
+    r"الإجابة الصحيحة هي:",
+    r"الإجابة هي:",
+    r"Respuesta\s*:",
+    r"Risposta\s*:",
+    r"答え\s*：",
+    r"答え\s*:",
+    r"回答\s*：",
+    r"回答\s*:",
+    r"解答\s*:",
+    r"Jawaban\s*:",
+    r"Réponse\s*:",
+    r"Resposta\s*:",
+    r"Jibu\s*:",
+    r"Idahun\s*:",
+    r"Ìdáhùn\s*:",
+    r"Idáhùn\s*:",
+    r"Àmọ̀nà\s*:",
+    r"Àdáhùn\s*:",
+    r"Ànúgọ\s*:",
+    r"Àṣàyàn\s*:",
+]
+
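+# Each entry above is a localized "Answer:" cue. The template below wraps a cue
+# with a capture group for a single choice letter (Latin A-D, Arabic أ-د, or a
+# Bengali letter); the inline (?i) flag makes the whole pattern case-insensitive.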
+MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
+    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
+)
+
+regex_parser_multiple_choice_answer = ScoringFnDef(
+    identifier="meta-reference::regex_parser_multiple_choice_answer",
+    description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
+    return_type=NumberType(),
+    params=RegexParserScoringFnParams(
+        parsing_regexes=[
+            MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
+            for x in MULTILINGUAL_ANSWER_REGEXES
+        ],
+    ),
+)
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py
new file mode 100644
index 000000000..70113cf48
--- /dev/null
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import re
+
+from .base_scoring_fn import BaseScoringFn
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from .common import aggregate_accuracy
+
+from .fn_defs.regex_parser_multiple_choice_answer import (
+    regex_parser_multiple_choice_answer,
+)
+
+
+class RegexParserScoringFn(BaseScoringFn):
+    """
+    A scoring_fn that parses the answer from the generated response via the fn_def's parsing regexes and checks it against expected_answer.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {
+            regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
+        }
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> ScoringResultRow:
+        assert (
+            scoring_fn_identifier is not None
+        ), "Scoring function identifier not found."
+        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+        assert (
+            fn_def.params is not None
+            and fn_def.params.type == ScoringConfigType.regex_parser.value
+        ), f"RegexParserScoringFnParams not found for {fn_def}."
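+
+        # fn_def.params.parsing_regexes is built from the multilingual answer
+        # patterns in fn_defs/regex_parser_multiple_choice_answer.py; the regexes
+        # are tried in order and the first match wins.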
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+
+        # parse answer according to regex
+        parsed_answer = None
+        for regex in fn_def.params.parsing_regexes:
+            match = re.search(regex, generated_answer)
+            if match:
+                parsed_answer = match.group(1)
+                break
+
+        score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        return aggregate_accuracy(scoring_results)