From 779e66f83f921c70b6cdab24a4b2eca12ef01fab Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 31 Oct 2024 16:44:02 -0700 Subject: [PATCH] registration answer parsing --- .../scoring_functions/scoring_functions.py | 18 ++++- .../impls/meta_reference/scoring/scoring.py | 41 ++++++----- .../scoring_fn/answer_parsing_scoring_fn.py | 61 ++++++++++++++++ .../scoring/scoring_fn/equality_scoring_fn.py | 4 +- .../fn_defs/answer_parsing_multiple_choice.py | 69 +++++++++++++++++++ .../scoring/scoring_fn/fn_defs/equality.py | 1 - .../fn_defs/llm_as_judge_8b_correctness.py | 1 - .../scoring/scoring_fn/fn_defs/subset_of.py | 1 - .../scoring_fn/llm_as_judge_scoring_fn.py | 15 ++-- .../scoring_fn/subset_of_scoring_fn.py | 12 +--- 10 files changed, 178 insertions(+), 45 deletions(-) create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/answer_parsing_scoring_fn.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/answer_parsing_multiple_choice.py diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 597a1abbe..746cbe7a8 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -4,7 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field @@ -42,7 +51,7 @@ class AnswerParsingContext(BaseModel): ScoringContextType.answer_parsing.value ) parsing_regex: Optional[List[str]] = Field( - "Regex to extract the answer from generated response", + description="Regex to extract the answer from generated response", default_factory=list, ) @@ -67,7 +76,10 @@ class ScoringFnDef(BaseModel): return_type: ParamType = Field( description="The return type of the deterministic function", ) - context: Optional[ScoringContext] = None + context: Optional[ScoringContext] = Field( + description="Scoring function context used for different answer extraction methods", + default=None, + ) # We can optionally add information here to support packaging of code, etc.
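The API change above is the core of this patch: ScoringFnDef gains an optional context, and AnswerParsingContext carries the parsing_regex list used to extract answers. As a minimal sketch (assuming AnswerParsingContext and ScoringFnDef are exported from llama_stack.apis.scoring_functions as the imports in this patch imply; the identifier and regex below are illustrative, not definitions added by the patch), a registerable definition would look like:

# Illustrative only: a ScoringFnDef carrying an AnswerParsingContext.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import AnswerParsingContext, ScoringFnDef

example_answer_parsing_fn = ScoringFnDef(
    identifier="example::answer_parsing_letter",  # hypothetical identifier
    description="Extract a single A-D letter following 'Answer:' and compare it to the target",
    return_type=NumberType(),
    context=AnswerParsingContext(
        # capture group 1 must hold the parsed answer, matching how score_row uses it
        parsing_regex=[r"(?i)Answer\s*:\s*([A-D])"],
    ),
)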
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index 41b24a512..e7c85b6fc 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -13,23 +13,20 @@ from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.inference.inference import Inference from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import ( - EqualityScoringFn, -) - -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import ( - LlmAsJudgeScoringFn, -) - -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import ( - SubsetOfScoringFn, -) from .config import MetaReferenceScoringConfig +from .scoring_fn.answer_parsing_scoring_fn import AnswerParsingScoringFn +from .scoring_fn.equality_scoring_fn import EqualityScoringFn +from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn +from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn] -LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] +# Scoring functions with context that can be registered +REGISTERABLE_SCORING_FNS = { + ScoringContextType.llm_as_judge.value: LlmAsJudgeScoringFn, + ScoringContextType.answer_parsing.value: AnswerParsingScoringFn, +} class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): @@ -44,18 +41,24 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): self.datasetio_api = datasetio_api self.datasets_api = datasets_api self.inference_api = inference_api + # keep track of scoring function id to impls self.scoring_fn_id_impls = {} + # registerable scoring fn context to impls + self.registerable_scoring_fn_impls = {} async def initialize(self) -> None: for x in FIXED_FNS: impl = x() for fn_defs in impl.get_supported_scoring_fn_defs(): self.scoring_fn_id_impls[fn_defs.identifier] = impl - for x in LLM_JUDGE_FNS: - impl = x(inference_api=self.inference_api) + for context_type, impl_cls in REGISTERABLE_SCORING_FNS.items(): + if context_type == ScoringContextType.llm_as_judge.value: + impl = impl_cls(inference_api=self.inference_api) + else: + impl = impl_cls() for fn_defs in impl.get_supported_scoring_fn_defs(): self.scoring_fn_id_impls[fn_defs.identifier] = impl - self.llm_as_judge_fn = impl + self.registerable_scoring_fn_impls[context_type] = impl async def shutdown(self) -> None: ... 
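The scoring.py change above replaces the single llm_as_judge_fn attribute with a context-type keyed dispatch: initialize() instantiates one impl per entry in REGISTERABLE_SCORING_FNS and records it in registerable_scoring_fn_impls. A standalone sketch of the same dispatch idea, using stand-in classes and plain dicts rather than the real BaseScoringFn and ScoringFnDef types:

# Standalone sketch of the context-type -> scoring-fn-impl dispatch used above.
# FakeScoringFn and the dict-based fn defs are stand-ins, not the provider's types.
from typing import Any, Dict


class FakeScoringFn:
    def __init__(self, name: str) -> None:
        self.name = name
        self.registry: Dict[str, Dict[str, Any]] = {}

    def register_scoring_fn_def(self, fn_def: Dict[str, Any]) -> None:
        self.registry[fn_def["identifier"]] = fn_def


# context type -> impl, mirroring registerable_scoring_fn_impls after initialize()
registerable_scoring_fn_impls = {
    "llm_as_judge": FakeScoringFn("judge"),
    "answer_parsing": FakeScoringFn("answer_parsing"),
}
scoring_fn_id_impls: Dict[str, FakeScoringFn] = {}


def register_scoring_function(fn_def: Dict[str, Any]) -> None:
    # Mirrors register_scoring_function below: a def is registerable only if it
    # carries a context, and the context type selects the implementation.
    assert fn_def.get("context") is not None, "Only defs with context can be registered"
    impl = registerable_scoring_fn_impls[fn_def["context"]["type"]]
    impl.register_scoring_fn_def(fn_def)
    scoring_fn_id_impls[fn_def["identifier"]] = impl


register_scoring_function(
    {"identifier": "example::mc_parser", "context": {"type": "answer_parsing"}}
)
assert "example::mc_parser" in scoring_fn_id_impls

Keying the impls by context type, rather than hard-coding a single llm_as_judge_fn, is what lets registration route answer-parsing defs to AnswerParsingScoringFn without special-casing.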
@@ -74,8 +77,12 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): return scoring_fn_defs_list async def register_scoring_function(self, function_def: ScoringFnDef) -> None: - self.llm_as_judge_fn.register_scoring_fn_def(function_def) - self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn + assert ( + function_def.context is not None + ), "Only ScoringFnDef with context set can be registered" + fn_impl = self.registerable_scoring_fn_impls[function_def.context.type] + fn_impl.register_scoring_fn_def(function_def) + self.scoring_fn_id_impls[function_def.identifier] = fn_impl async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/answer_parsing_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/answer_parsing_scoring_fn.py new file mode 100644 index 000000000..e8e2667e8 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/answer_parsing_scoring_fn.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import re + +from .base_scoring_fn import BaseScoringFn +from llama_stack.apis.scoring_functions import * # noqa: F401, F403 +from llama_stack.apis.scoring import * # noqa: F401, F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from .common import aggregate_accuracy + +from .fn_defs.answer_parsing_multiple_choice import answer_parsing_multiple_choice + + +class AnswerParsingScoringFn(BaseScoringFn): + """ + A scoring_fn that parses the answer from the generated response according to the context and checks whether it matches the expected_answer. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.supported_fn_defs_registry = { + answer_parsing_multiple_choice.identifier: answer_parsing_multiple_choice, + } + + async def score_row( + self, + input_row: Dict[str, Any], + scoring_fn_identifier: Optional[str] = None, + ) -> ScoringResultRow: + assert ( + scoring_fn_identifier is not None + ), "Scoring function identifier not found." + fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] + assert ( + fn_def.context is not None + and fn_def.context.type == ScoringContextType.answer_parsing.value + ), f"AnswerParsingContext not found for {fn_def}."
+ + expected_answer = input_row["expected_answer"] + generated_answer = input_row["generated_answer"] + + # parse answer according to regex + parsed_answer = None + for regex in fn_def.context.parsing_regex: + match = re.search(regex, generated_answer) + if match: + parsed_answer = match.group(1) + break + + score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0 + return { + "score": score, + } + + async def aggregate( + self, scoring_results: List[ScoringResultRow] + ) -> Dict[str, Any]: + return aggregate_accuracy(scoring_results) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py index 556436286..7dd446151 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py @@ -15,9 +15,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import aggregate_accuracy, ) -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import ( - equality, -) +from .fn_defs.equality import equality class EqualityScoringFn(BaseScoringFn): diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/answer_parsing_multiple_choice.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/answer_parsing_multiple_choice.py new file mode 100644 index 000000000..1e6040efe --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/answer_parsing_multiple_choice.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
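score_row above walks the context's parsing_regex list, takes capture group 1 from the first regex that matches, and scores 1.0 only on an exact match with expected_answer. A self-contained illustration of that loop, with a made-up response and a simplified regex (neither is part of the patch):

# Self-contained illustration of the parsing loop in score_row above.
import re

parsing_regex = [r"(?i)Answer\s*:\s*([A-D])"]
generated_answer = "Let me reason about the options... Answer: B"
expected_answer = "B"

parsed_answer = None
for regex in parsing_regex:
    match = re.search(regex, generated_answer)
    if match:
        # group(1) is the captured answer, as in AnswerParsingScoringFn.score_row
        parsed_answer = match.group(1)
        break

score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
assert score == 1.0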
+ +from llama_stack.apis.scoring_functions import * # noqa: F401, F403 +from llama_stack.apis.scoring import * # noqa: F401, F403 +from llama_stack.apis.common.type_system import NumberType + +MULTILINGUAL_ANSWER_REGEXES = [ + r"Answer\s*:", + r"Answer\s*:​​​​​​", # Korean invisible character + r"উত্তর\s*:", + r"उत्तर\s*:", + r"উত্তরঃ", + r"উত্তর\s*:", + r"Antwort\s*:", + r"답변\s*:", + r"정답\s*:", + r"답\s*:", + r"答案\s*:", + r"答案\s*:", + r"答\s*:", + r"答\s*:", + r"答复\s*:", + r"答曰\s*:", + r"الإجابة:", + r"الجواب:", + r"إجابة:", + r"الإجابة النهائية:", + r"الإجابة الصحيحة:", + r"الإجابة الصحيحة هي:", + r"الإجابة هي:", + r"Respuesta\s*:", + r"Risposta\s*:", + r"答え\s*:", + r"答え\s*:", + r"回答\s*:", + r"回答\s*:", + r"解答\s*:", + r"Jawaban\s*:", + r"Réponse\s*:", + r"Resposta\s*:", + r"Jibu\s*:", + r"Idahun\s*:", + r"Ìdáhùn\s*:", + r"Idáhùn\s*:", + r"Àmọ̀nà\s*:", + r"Àdáhùn\s*:", + r"Ànúgọ\s*:", + r"Àṣàyàn\s*:", +] + +MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = ( + r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])" +) + +answer_parsing_multiple_choice = ScoringFnDef( + identifier="meta-reference::answer_parsing_multiple_choice", + description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result", + return_type=NumberType(), + context=AnswerParsingContext( + parsing_regex=[ + MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) + for x in MULTILINGUAL_ANSWER_REGEXES + ], + ), +) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py index 99fa6cc3a..b54bf7ae8 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py @@ -11,6 +11,5 @@ from llama_stack.apis.scoring_functions import ScoringFnDef equality = ScoringFnDef( identifier="meta-reference::equality", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", - parameters=[], return_type=NumberType(), ) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py index 20a67edc7..2838f04ed 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py @@ -26,7 +26,6 @@ Total rating: llm_as_judge_8b_correctness = ScoringFnDef( identifier="meta-reference::llm_as_judge_8b_correctness", description="Llm As Judge Scoring Function", - parameters=[], return_type=NumberType(), context=LLMAsJudgeContext( prompt_template=JUDGE_PROMPT, diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py index 5a3e2e8fb..2337c3e30 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py @@ -11,6 +11,5 @@ from llama_stack.apis.scoring_functions import ScoringFnDef subset_of = ScoringFnDef( identifier="meta-reference::subset_of", description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", - parameters=[], return_type=NumberType(), ) diff --git 
a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py index 5a5ce2550..7d4ec94cf 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py @@ -4,25 +4,20 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from llama_stack.apis.inference.inference import Inference -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import ( - BaseScoringFn, -) + +from .base_scoring_fn import BaseScoringFn from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 import re -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( - aggregate_average, -) -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import ( - llm_as_judge_8b_correctness, -) +from .common import aggregate_average +from .fn_defs.llm_as_judge_8b_correctness import llm_as_judge_8b_correctness class LlmAsJudgeScoringFn(BaseScoringFn): """ - A scoring_fn that assigns + A scoring_fn that uses an LLM as judge to produce a score. """ def __init__(self, inference_api: Inference, *arg, **kwargs) -> None: diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py index fcef2ead7..e8db2cb0f 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py @@ -4,19 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import ( - BaseScoringFn, -) +from .base_scoring_fn import BaseScoringFn from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( - aggregate_accuracy, -) +from .common import aggregate_accuracy -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import ( - subset_of, -) +from .fn_defs.subset_of import subset_of class SubsetOfScoringFn(BaseScoringFn):
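As a quick, illustrative check of the multiple-choice pattern assembled in fn_defs/answer_parsing_multiple_choice.py (the sample response below is made up; the template and the Answer prefix regex mirror the definitions in this patch):

# Illustrative check that the formatted multiple-choice pattern matches a
# typical English response and captures the answer letter in group 1.
import re

MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
)
pattern = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(r"Answer\s*:")

match = re.search(pattern, "The distractors can be ruled out, so the Answer: C")
assert match is not None and match.group(1) == "C"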