diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
index b0683dd04..57723bb47 100644
--- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py
+++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
@@ -17,7 +17,7 @@ from autoevals.llm import Factuality
 from autoevals.ragas import AnswerCorrectness
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from ..meta_reference.scoring.scoring_fn.common import aggregate_average
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
 from .config import BraintrustScoringConfig
 from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/equality_scoring_fn.py
index 89e516663..877b64e4e 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/equality_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/equality_scoring_fn.py
@@ -9,7 +9,8 @@ from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
-from .common import aggregate_accuracy
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
+
 from .fn_defs.equality import equality
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
deleted file mode 100644
index 68d77b8df..000000000
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.apis.scoring_functions import * # noqa: F401, F403
-from llama_stack.apis.scoring import * # noqa: F401, F403
-from llama_stack.apis.common.type_system import NumberType
-
-JUDGE_PROMPT = """
-You will be given a question, a expected_answer, and a system_answer.
-Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
-Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
-Provide your feedback as follows:
-Feedback:::
-Total rating: (your rating, as a int between 0 and 5)
-Now here are the question, expected_answer, system_answer.
-Question: {input_query}
-Expected Answer: {expected_answer}
-System Answer: {generated_answer}
-Feedback:::
-Total rating:
-"""
-
-llm_as_judge_8b_correctness = ScoringFnDef(
-    identifier="meta-reference::llm_as_judge_8b_correctness",
-    description="Llm As Judge Scoring Function",
-    return_type=NumberType(),
-    params=LLMAsJudgeScoringFnParams(
-        prompt_template=JUDGE_PROMPT,
-        judge_model="Llama3.1-8B-Instruct",
-        judge_score_regexes=[
-            r"Total rating: (\d+)",
-            r"rating: (\d+)",
-            r"Rating: (\d+)",
-        ],
-    ),
-)
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_base.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_base.py
new file mode 100644
index 000000000..f7de54f46
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_base.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.scoring_functions import * # noqa: F401, F403
+from llama_stack.apis.scoring import * # noqa: F401, F403
+from llama_stack.apis.common.type_system import NumberType
+
+# JUDGE_PROMPT = """
+# You will be given a question, a expected_answer, and a system_answer.
+# Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
+# Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
+# Provide your feedback as follows:
+# Feedback:::
+# Total rating: (your rating, as a int between 0 and 5)
+# Now here are the question, expected_answer, system_answer.
+# Question: {input_query}
+# Expected Answer: {expected_answer}
+# System Answer: {generated_answer}
+# Feedback:::
+# Total rating:
+# """
+
+llm_as_judge_base = ScoringFnDef(
+    identifier="meta-reference::llm_as_judge_base",
+    description="Llm As Judge Scoring Function",
+    return_type=NumberType(),
+    # params=LLMAsJudgeScoringFnParams(
+    #     prompt_template=JUDGE_PROMPT,
+    #     judge_model="Llama3.1-8B-Instruct",
+    #     judge_score_regexes=[
+    #         r"Total rating: (\d+)",
+    #         r"rating: (\d+)",
+    #         r"Rating: (\d+)",
+    #     ],
+    # ),
+)
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/llm_as_judge_scoring_fn.py
index 24bdc6400..e1f19e640 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/llm_as_judge_scoring_fn.py
@@ -11,8 +11,9 @@ from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
 import re
-from .common import aggregate_average
-from .fn_defs.llm_as_judge_8b_correctness import llm_as_judge_8b_correctness
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
+
+from .fn_defs.llm_as_judge_base import llm_as_judge_base
 class LlmAsJudgeScoringFn(BaseScoringFn):
@@ -24,7 +25,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
         super().__init__(*arg, **kwargs)
         self.inference_api = inference_api
         self.supported_fn_defs_registry = {
-            llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness,
+            llm_as_judge_base.identifier: llm_as_judge_base,
         }
 
     async def score_row(
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py
index 0aff2f535..3cbc6cbe4 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py
@@ -9,7 +9,7 @@ from .base_scoring_fn import BaseScoringFn
 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
-from .common import aggregate_accuracy
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
 from .fn_defs.regex_parser_multiple_choice_answer import (
     regex_parser_multiple_choice_answer,
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/subset_of_scoring_fn.py
index d484e182c..fe5988160 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/subset_of_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/subset_of_scoring_fn.py
@@ -8,7 +8,7 @@ from .base_scoring_fn import BaseScoringFn
 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
-from .common import aggregate_accuracy
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
 from .fn_defs.subset_of import subset_of
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 170073eeb..1f2608f3b 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -61,9 +61,9 @@ class TestScoring:
         assert len(rows.rows) == 3
 
         scoring_functions = {
-            "meta-reference::llm_as_judge_8b_correctness": None,
             "meta-reference::equality": None,
         }
+
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
@@ -116,7 +116,7 @@ class TestScoring:
         assert len(rows.rows) == 3
 
         scoring_functions = {
-            "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
+            "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: , where is the number between 0 and 9.",
                 judge_score_regexes=[r"Score: (\d+)"],
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/common.py b/llama_stack/providers/utils/scoring/aggregation_utils.py
similarity index 92%
rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/common.py
rename to llama_stack/providers/utils/scoring/aggregation_utils.py
index 25bac5edc..1ca0c7fb3 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/common.py
+++ b/llama_stack/providers/utils/scoring/aggregation_utils.py
@@ -3,13 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from pathlib import Path
 from typing import Any, Dict, List
 
 from llama_stack.apis.scoring import ScoringResultRow
 
-FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
-
 
 def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
     num_correct = sum(result["score"] for result in scoring_results)
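
For context on the rename at the end of the patch: the last hunk is truncated after the first line of aggregate_accuracy, so below is a minimal sketch of what the relocated llama_stack/providers/utils/scoring/aggregation_utils.py helpers plausibly contain. The function bodies are an assumption inferred from the function names and the visible num_correct line, not the patch's exact code.

# Sketch only (assumption): plausible contents of the relocated
# aggregation_utils module; the real file may differ in detail.
from typing import Any, Dict, List

from llama_stack.apis.scoring import ScoringResultRow


def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    # Treat each row's "score" as 1.0 (correct) or 0.0 (incorrect) and
    # report the fraction of correct rows.
    num_correct = sum(result["score"] for result in scoring_results)
    return {
        "accuracy": num_correct / len(scoring_results),
        "num_correct": num_correct,
        "num_total": len(scoring_results),
    }


def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    # Report the mean of the per-row "score" values, e.g. ratings emitted
    # by an LLM-as-judge scoring function.
    return {
        "average": sum(result["score"] for result in scoring_results)
        / len(scoring_results),
    }

With the helpers in a shared providers/utils module, both the Braintrust provider and the meta-reference scoring functions import them via "from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy" (or aggregate_average) instead of reaching into ..meta_reference.scoring.scoring_fn.common, which is what the import changes earlier in the patch accomplish.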