refactor scoring

This commit is contained in:
Xi Yan 2024-11-11 15:48:07 -05:00
parent aa66410f24
commit a6038ffee9
9 changed files with 50 additions and 51 deletions

View file

@ -17,7 +17,7 @@ from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from ..meta_reference.scoring.scoring_fn.common import aggregate_average
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def

View file

@ -9,7 +9,8 @@ from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from .common import aggregate_accuracy
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.equality import equality

View file

@ -1,39 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType
JUDGE_PROMPT = """
You will be given a question, a expected_answer, and a system_answer.
Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
Provide your feedback as follows:
Feedback:::
Total rating: (your rating, as a int between 0 and 5)
Now here are the question, expected_answer, system_answer.
Question: {input_query}
Expected Answer: {expected_answer}
System Answer: {generated_answer}
Feedback:::
Total rating:
"""
llm_as_judge_8b_correctness = ScoringFnDef(
identifier="meta-reference::llm_as_judge_8b_correctness",
description="Llm As Judge Scoring Function",
return_type=NumberType(),
params=LLMAsJudgeScoringFnParams(
prompt_template=JUDGE_PROMPT,
judge_model="Llama3.1-8B-Instruct",
judge_score_regexes=[
r"Total rating: (\d+)",
r"rating: (\d+)",
r"Rating: (\d+)",
],
),
)

View file

@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType
# JUDGE_PROMPT = """
# You will be given a question, a expected_answer, and a system_answer.
# Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
# Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
# Provide your feedback as follows:
# Feedback:::
# Total rating: (your rating, as a int between 0 and 5)
# Now here are the question, expected_answer, system_answer.
# Question: {input_query}
# Expected Answer: {expected_answer}
# System Answer: {generated_answer}
# Feedback:::
# Total rating:
# """
llm_as_judge_base = ScoringFnDef(
identifier="meta-reference::llm_as_judge_base",
description="Llm As Judge Scoring Function",
return_type=NumberType(),
# params=LLMAsJudgeScoringFnParams(
# prompt_template=JUDGE_PROMPT,
# judge_model="Llama3.1-8B-Instruct",
# judge_score_regexes=[
# r"Total rating: (\d+)",
# r"rating: (\d+)",
# r"Rating: (\d+)",
# ],
# ),
)

View file

@ -11,8 +11,9 @@ from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
import re
from .common import aggregate_average
from .fn_defs.llm_as_judge_8b_correctness import llm_as_judge_8b_correctness
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
from .fn_defs.llm_as_judge_base import llm_as_judge_base
class LlmAsJudgeScoringFn(BaseScoringFn):
@ -24,7 +25,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
super().__init__(*arg, **kwargs)
self.inference_api = inference_api
self.supported_fn_defs_registry = {
llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness,
llm_as_judge_base.identifier: llm_as_judge_base,
}
async def score_row(

View file

@ -9,7 +9,7 @@ from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from .common import aggregate_accuracy
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,

View file

@ -8,7 +8,7 @@ from .base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from .common import aggregate_accuracy
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.subset_of import subset_of

View file

@ -61,9 +61,9 @@ class TestScoring:
assert len(rows.rows) == 3
scoring_functions = {
"meta-reference::llm_as_judge_8b_correctness": None,
"meta-reference::equality": None,
}
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
@ -116,7 +116,7 @@ class TestScoring:
assert len(rows.rows) == 3
scoring_functions = {
"meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
"meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-405B-Instruct",
prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
judge_score_regexes=[r"Score: (\d+)"],

View file

@ -3,13 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from typing import Any, Dict, List
from llama_stack.apis.scoring import ScoringResultRow
FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
num_correct = sum(result["score"] for result in scoring_results)