Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-31 16:01:46 +00:00)
refactor scoring
This commit is contained in:
parent aa66410f24
commit a6038ffee9
9 changed files with 50 additions and 51 deletions
@@ -17,7 +17,7 @@ from autoevals.llm import Factuality
 from autoevals.ragas import AnswerCorrectness
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate

-from ..meta_reference.scoring.scoring_fn.common import aggregate_average
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average

 from .config import BraintrustScoringConfig
 from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
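The relative import into the meta-reference provider is replaced by the shared helper under llama_stack.providers.utils.scoring. A minimal sketch of what that helper plausibly looks like, assuming it mirrors the aggregate_accuracy snippet shown at the end of this diff (the real aggregation_utils.py may differ):

# Hypothetical sketch of the shared helper; treat the body as an assumption.
from typing import Any, Dict, List

from llama_stack.apis.scoring import ScoringResultRow


def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    # Average the per-row "score" values into a single aggregated metric.
    if not scoring_results:
        return {"average": 0.0}
    total = sum(result["score"] for result in scoring_results)
    return {"average": total / len(scoring_results)}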
@@ -9,7 +9,8 @@ from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403

-from .common import aggregate_accuracy
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy

 from .fn_defs.equality import equality
@@ -1,39 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
-from llama_stack.apis.scoring import *  # noqa: F401, F403
-from llama_stack.apis.common.type_system import NumberType
-
-JUDGE_PROMPT = """
-You will be given a question, a expected_answer, and a system_answer.
-Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
-Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
-Provide your feedback as follows:
-Feedback:::
-Total rating: (your rating, as a int between 0 and 5)
-Now here are the question, expected_answer, system_answer.
-Question: {input_query}
-Expected Answer: {expected_answer}
-System Answer: {generated_answer}
-Feedback:::
-Total rating:
-"""
-
-llm_as_judge_8b_correctness = ScoringFnDef(
-    identifier="meta-reference::llm_as_judge_8b_correctness",
-    description="Llm As Judge Scoring Function",
-    return_type=NumberType(),
-    params=LLMAsJudgeScoringFnParams(
-        prompt_template=JUDGE_PROMPT,
-        judge_model="Llama3.1-8B-Instruct",
-        judge_score_regexes=[
-            r"Total rating: (\d+)",
-            r"rating: (\d+)",
-            r"Rating: (\d+)",
-        ],
-    ),
-)
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import NumberType
+
+# JUDGE_PROMPT = """
+# You will be given a question, a expected_answer, and a system_answer.
+# Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
+# Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
+# Provide your feedback as follows:
+# Feedback:::
+# Total rating: (your rating, as a int between 0 and 5)
+# Now here are the question, expected_answer, system_answer.
+# Question: {input_query}
+# Expected Answer: {expected_answer}
+# System Answer: {generated_answer}
+# Feedback:::
+# Total rating:
+# """
+
+llm_as_judge_base = ScoringFnDef(
+    identifier="meta-reference::llm_as_judge_base",
+    description="Llm As Judge Scoring Function",
+    return_type=NumberType(),
+    # params=LLMAsJudgeScoringFnParams(
+    #     prompt_template=JUDGE_PROMPT,
+    #     judge_model="Llama3.1-8B-Instruct",
+    #     judge_score_regexes=[
+    #         r"Total rating: (\d+)",
+    #         r"rating: (\d+)",
+    #         r"Rating: (\d+)",
+    #     ],
+    # ),
+)
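The new llm_as_judge_base definition drops the hard-coded judge prompt and model; callers are now expected to supply LLMAsJudgeScoringFnParams themselves (see the test change further down). A hedged sketch of such call-site params, reusing the placeholders and fields that appear in this diff; the exact prompt wording here is illustrative only:

# Sketch only: field names and placeholders come from this diff, everything else is an assumption.
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams

judge_params = LLMAsJudgeScoringFnParams(
    judge_model="Llama3.1-8B-Instruct",
    prompt_template=(
        "Question: {input_query}\n"
        "Expected Answer: {expected_answer}\n"
        "System Answer: {generated_answer}\n"
        "Total rating:"
    ),
    judge_score_regexes=[r"Total rating: (\d+)"],
)

# Passed per scoring function when calling the scoring API, e.g.:
# scoring_functions = {"meta-reference::llm_as_judge_base": judge_params}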
@@ -11,8 +11,9 @@ from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 import re

-from .common import aggregate_average
-from .fn_defs.llm_as_judge_8b_correctness import llm_as_judge_8b_correctness
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
+
+from .fn_defs.llm_as_judge_base import llm_as_judge_base


 class LlmAsJudgeScoringFn(BaseScoringFn):
@@ -24,7 +25,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
         super().__init__(*arg, **kwargs)
         self.inference_api = inference_api
         self.supported_fn_defs_registry = {
-            llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness,
+            llm_as_judge_base.identifier: llm_as_judge_base,
         }

     async def score_row(
@@ -9,7 +9,7 @@ from .base_scoring_fn import BaseScoringFn
 from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
-from .common import aggregate_accuracy
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy

 from .fn_defs.regex_parser_multiple_choice_answer import (
     regex_parser_multiple_choice_answer,
@@ -8,7 +8,7 @@ from .base_scoring_fn import BaseScoringFn
 from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
-from .common import aggregate_accuracy
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy

 from .fn_defs.subset_of import subset_of

@@ -61,9 +61,9 @@ class TestScoring:
         assert len(rows.rows) == 3

         scoring_functions = {
-            "meta-reference::llm_as_judge_8b_correctness": None,
             "meta-reference::equality": None,
         }

         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
@@ -116,7 +116,7 @@ class TestScoring:
         assert len(rows.rows) == 3

         scoring_functions = {
-            "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
+            "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
                 judge_score_regexes=[r"Score: (\d+)"],
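The prompt_template and judge_score_regexes above have to agree. A quick sanity check of how that single regex would extract the rating from a judge reply, assuming the judge follows the requested format:

import re

# Hypothetical judge reply in the format the prompt_template asks for.
judge_reply = "Score: 7"

match = re.search(r"Score: (\d+)", judge_reply)
score = int(match.group(1)) if match else None
assert score == 7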
@@ -3,13 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from pathlib import Path
 from typing import Any, Dict, List

 from llama_stack.apis.scoring import ScoringResultRow

-FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
-

 def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
     num_correct = sum(result["score"] for result in scoring_results)
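The hunk is truncated by the page render here. A plausible completion of aggregate_accuracy, assuming accuracy is simply the mean of the per-row scores; the actual body in aggregation_utils.py may differ:

# Assumed completion; only the first two lines of the function are shown in the diff above.
from typing import Any, Dict, List

from llama_stack.apis.scoring import ScoringResultRow


def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    num_correct = sum(result["score"] for result in scoring_results)
    num_total = len(scoring_results)
    return {
        "accuracy": num_correct / num_total if num_total else 0.0,
        "num_correct": num_correct,
        "num_total": num_total,
    }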