Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 07:39:38 +00:00)

Commit 779e66f83f: "registration answer parsing"
Parent: f1a2548ad5
10 changed files with 178 additions and 45 deletions
@@ -4,7 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)

 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
@@ -42,7 +51,7 @@ class AnswerParsingContext(BaseModel):
         ScoringContextType.answer_parsing.value
     )
     parsing_regex: Optional[List[str]] = Field(
-        "Regex to extract the answer from generated response",
+        description="Regex to extract the answer from generated response",
         default_factory=list,
     )

@@ -67,7 +76,10 @@ class ScoringFnDef(BaseModel):
     return_type: ParamType = Field(
         description="The return type of the deterministic function",
     )
-    context: Optional[ScoringContext] = None
+    context: Optional[ScoringContext] = Field(
+        description="Scoring function context used different answer extraction",
+        default=None,
+    )
     # We can optionally add information here to support packaging of code, etc.

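The hunks above touch the scoring-functions API module, where AnswerParsingContext and ScoringFnDef are defined: Union is added to the imports, parsing_regex gains an explicit description, and ScoringFnDef.context becomes a documented Field. As a rough sketch of a definition built against these types (the identifier, description, and regex below are invented for illustration and are not part of this commit):

# Illustrative only: a hypothetical ScoringFnDef carrying an AnswerParsingContext,
# assuming both names are importable from llama_stack.apis.scoring_functions as the
# new fn_defs file later in this diff does.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import AnswerParsingContext, ScoringFnDef

yes_no_parser = ScoringFnDef(  # hypothetical identifier and description
    identifier="example::answer_parsing_yes_no",
    description="Extract 'Answer: yes/no' from the response and compare with expected_answer",
    return_type=NumberType(),
    context=AnswerParsingContext(
        # group(1) of the first matching regex becomes the parsed answer
        parsing_regex=[r"(?i)answer\s*:\s*(yes|no)"],
    ),
)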
@@ -13,23 +13,20 @@ from llama_stack.apis.datasetio import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403
 from llama_stack.apis.inference.inference import Inference
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
-    EqualityScoringFn,
-)
-
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
-    LlmAsJudgeScoringFn,
-)
-
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
-    SubsetOfScoringFn,
-)
-
 from .config import MetaReferenceScoringConfig
+from .scoring_fn.answer_parsing_scoring_fn import AnswerParsingScoringFn
+from .scoring_fn.equality_scoring_fn import EqualityScoringFn
+from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
+from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn

 FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]

-LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
+# Scoring functions with context that can be registered
+REGISTERABLE_SCORING_FNS = {
+    ScoringContextType.llm_as_judge.value: LlmAsJudgeScoringFn,
+    ScoringContextType.answer_parsing.value: AnswerParsingScoringFn,
+}


 class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
@@ -44,18 +41,24 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
         self.inference_api = inference_api
+        # keep track of scoring function id to impls
         self.scoring_fn_id_impls = {}
+        # registerable scoring fn context to impls
+        self.registerable_scoring_fn_impls = {}

     async def initialize(self) -> None:
         for x in FIXED_FNS:
             impl = x()
             for fn_defs in impl.get_supported_scoring_fn_defs():
                 self.scoring_fn_id_impls[fn_defs.identifier] = impl
-        for x in LLM_JUDGE_FNS:
-            impl = x(inference_api=self.inference_api)
+        for context_type, impl_cls in REGISTERABLE_SCORING_FNS.items():
+            if context_type == ScoringContextType.llm_as_judge.value:
+                impl = impl_cls(inference_api=self.inference_api)
+            else:
+                impl = impl_cls()
             for fn_defs in impl.get_supported_scoring_fn_defs():
                 self.scoring_fn_id_impls[fn_defs.identifier] = impl
-            self.llm_as_judge_fn = impl
+            self.registerable_scoring_fn_impls[context_type] = impl

     async def shutdown(self) -> None: ...

@@ -74,8 +77,12 @@
         return scoring_fn_defs_list

     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
-        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
-        self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
+        assert (
+            function_def.context is not None
+        ), "Only ScoringFnDef with context set can be registered"
+        fn_impl = self.registerable_scoring_fn_impls[function_def.context.type]
+        fn_impl.register_scoring_fn_def(function_def)
+        self.scoring_fn_id_impls[function_def.identifier] = fn_impl

     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
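The meta-reference scoring provider hunks above replace the LLM_JUDGE_FNS special case with a context-type registry: initialize() instantiates one handler per registerable context type, and register_scoring_function() now requires a context and routes the definition to the matching handler. A minimal usage sketch, assuming an already-initialized impl and the hypothetical yes_no_parser sketched earlier:

# Illustrative only; `impl` is assumed to be a MetaReferenceScoringImpl that has
# already run initialize(), so registerable_scoring_fn_impls is populated.
async def register_example(impl, yes_no_parser) -> None:
    # routed internally via registerable_scoring_fn_impls[yes_no_parser.context.type]
    await impl.register_scoring_function(yes_no_parser)
    # the new identifier is now known to the scoring impl
    assert yes_no_parser.identifier in impl.scoring_fn_id_impls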
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import re
+
+from .base_scoring_fn import BaseScoringFn
+from llama_stack.apis.scoring_functions import * # noqa: F401, F403
+from llama_stack.apis.scoring import * # noqa: F401, F403
+from llama_stack.apis.common.type_system import * # noqa: F403
+from .common import aggregate_accuracy
+
+from .fn_defs.answer_parsing_multiple_choice import answer_parsing_multiple_choice
+
+
+class AnswerParsingScoringFn(BaseScoringFn):
+    """
+    A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {
+            answer_parsing_multiple_choice.identifier: answer_parsing_multiple_choice,
+        }
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> ScoringResultRow:
+        assert (
+            scoring_fn_identifier is not None
+        ), "Scoring function identifier not found."
+        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+        assert (
+            fn_def.context is not None
+            and fn_def.context.type == ScoringContextType.answer_parsing.value
+        ), f"AnswerParsingContext not found for {fn_def}."
+
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+
+        # parse answer according to regex
+        parsed_answer = None
+        for regex in fn_def.context.parsing_regex:
+            match = re.search(regex, generated_answer)
+            if match:
+                parsed_answer = match.group(1)
+                break
+
+        score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        return aggregate_accuracy(scoring_results)
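The new AnswerParsingScoringFn scores a row by trying each regex from the context in order, taking group(1) of the first match as the parsed answer, and awarding 1.0 only on an exact match with expected_answer. The parsing step in isolation, with a made-up response string (not from this commit):

# Illustrative only: the core parsing loop from score_row, applied outside
# of the scoring framework to a made-up generated answer.
import re

parsing_regex = [r"(?i)Answer\s*:\s*([A-D])"]  # hypothetical single-pattern context
generated_answer = "Let's think step by step... Answer: C"
expected_answer = "C"

parsed_answer = None
for regex in parsing_regex:
    match = re.search(regex, generated_answer)
    if match:
        parsed_answer = match.group(1)  # first capture group is the answer
        break

score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
print(parsed_answer, score)  # C 1.0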
@@ -15,9 +15,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import
     aggregate_accuracy,
 )

-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
-    equality,
-)
+from .fn_defs.equality import equality


 class EqualityScoringFn(BaseScoringFn):
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.scoring_functions import * # noqa: F401, F403
+from llama_stack.apis.scoring import * # noqa: F401, F403
+from llama_stack.apis.common.type_system import NumberType
+
+MULTILINGUAL_ANSWER_REGEXES = [
+    r"Answer\s*:",
+    r"Answer\s*:",  # Korean invisible character
+    r"উত্তর\s*:",
+    r"उत्तर\s*:",
+    r"উত্তরঃ",
+    r"উত্তর\s*:",
+    r"Antwort\s*:",
+    r"답변\s*:",
+    r"정답\s*:",
+    r"답\s*:",
+    r"答案\s*:",
+    r"答案\s*:",
+    r"答\s*:",
+    r"答\s*:",
+    r"答复\s*:",
+    r"答曰\s*:",
+    r"الإجابة:",
+    r"الجواب:",
+    r"إجابة:",
+    r"الإجابة النهائية:",
+    r"الإجابة الصحيحة:",
+    r"الإجابة الصحيحة هي:",
+    r"الإجابة هي:",
+    r"Respuesta\s*:",
+    r"Risposta\s*:",
+    r"答え\s*:",
+    r"答え\s*:",
+    r"回答\s*:",
+    r"回答\s*:",
+    r"解答\s*:",
+    r"Jawaban\s*:",
+    r"Réponse\s*:",
+    r"Resposta\s*:",
+    r"Jibu\s*:",
+    r"Idahun\s*:",
+    r"Ìdáhùn\s*:",
+    r"Idáhùn\s*:",
+    r"Àmọ̀nà\s*:",
+    r"Àdáhùn\s*:",
+    r"Ànúgọ\s*:",
+    r"Àṣàyàn\s*:",
+]
+
+MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
+    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
+)
+
+answer_parsing_multiple_choice = ScoringFnDef(
+    identifier="meta-reference::answer_parsing_multiple_choice",
+    description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
+    return_type=NumberType(),
+    context=AnswerParsingContext(
+        parsing_regex=[
+            MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
+            for x in MULTILINGUAL_ANSWER_REGEXES
+        ],
+    ),
+)
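The new fn_def above builds one parsing regex per language prefix by formatting each entry of MULTILINGUAL_ANSWER_REGEXES into MULTILINGUAL_ANSWER_PATTERN_TEMPLATE, so the capture group is the extracted answer letter. A quick standalone check of one formatted pattern, using a made-up response:

# Illustrative only: how a single parsing_regex entry behaves after formatting.
import re

MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
)

pattern = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(r"Answer\s*:")
match = re.search(pattern, "The best option is clear.\nAnswer: B")
print(match.group(1))  # B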
@@ -11,6 +11,5 @@ from llama_stack.apis.scoring_functions import ScoringFnDef
 equality = ScoringFnDef(
     identifier="meta-reference::equality",
     description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
-    parameters=[],
     return_type=NumberType(),
 )
@@ -26,7 +26,6 @@ Total rating:
 llm_as_judge_8b_correctness = ScoringFnDef(
     identifier="meta-reference::llm_as_judge_8b_correctness",
     description="Llm As Judge Scoring Function",
-    parameters=[],
     return_type=NumberType(),
     context=LLMAsJudgeContext(
         prompt_template=JUDGE_PROMPT,
@@ -11,6 +11,5 @@ from llama_stack.apis.scoring_functions import ScoringFnDef
 subset_of = ScoringFnDef(
     identifier="meta-reference::subset_of",
     description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
-    parameters=[],
     return_type=NumberType(),
 )
@@ -4,25 +4,20 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from llama_stack.apis.inference.inference import Inference
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
-    BaseScoringFn,
-)
+from .base_scoring_fn import BaseScoringFn
 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
 import re

-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
-    aggregate_average,
-)
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
-    llm_as_judge_8b_correctness,
-)
+from .common import aggregate_average
+from .fn_defs.llm_as_judge_8b_correctness import llm_as_judge_8b_correctness


 class LlmAsJudgeScoringFn(BaseScoringFn):
     """
-    A scoring_fn that assigns
+    A scoring_fn using LLM as Judge to produce score
     """

     def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
@@ -4,19 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
-    BaseScoringFn,
-)
+from .base_scoring_fn import BaseScoringFn
 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
-    aggregate_accuracy,
-)
+from .common import aggregate_accuracy

-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
-    subset_of,
-)
+from .fn_defs.subset_of import subset_of


 class SubsetOfScoringFn(BaseScoringFn):