diff --git a/llama_stack/providers/tests/eval/constants.py b/llama_stack/providers/tests/eval/constants.py
new file mode 100644
index 000000000..0fb1a44c4
--- /dev/null
+++ b/llama_stack/providers/tests/eval/constants.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+JUDGE_PROMPT = """
+You will be given a question, an expected_answer, and a system_answer.
+Your task is to provide a 'total rating' scoring how well the system_answer answers the question compared with the ground truth in expected_answer, in terms of factual correctness.
+Give your answer as an integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
+Provide your feedback as follows:
+Feedback:::
+Total rating: (your rating, as an integer between 0 and 5)
+Now here are the question, expected_answer, and system_answer.
+Question: {input_query}
+Expected Answer: {expected_answer}
+System Answer: {generated_answer}
+Feedback:::
+Total rating:
+"""
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index bdd5c8de0..9f14c61ef 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -19,9 +19,10 @@ from llama_stack.apis.eval.eval import (
     EvalTaskDefWithProvider,
     ModelCandidate,
 )
+from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
-
+from .constants import JUDGE_PROMPT
 
 # How to run this test:
 #
@@ -65,6 +66,7 @@ class Testeval:
         assert len(rows.rows) == 3
 
         scoring_functions = [
+            "meta-reference::llm_as_judge_base",
             "meta-reference::equality",
         ]
         task_id = "meta-reference::app_eval"
@@ -84,10 +86,22 @@ class Testeval:
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),
                 ),
+                scoring_params={
+                    "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
+                        judge_model="Llama3.1-8B-Instruct",
+                        prompt_template=JUDGE_PROMPT,
+                        judge_score_regexes=[
+                            r"Total rating: (\d+)",
+                            r"rating: (\d+)",
+                            r"Rating: (\d+)",
+                        ],
+                    )
+                },
             ),
         )
         assert len(response.generations) == 3
         assert "meta-reference::equality" in response.scores
+        assert "meta-reference::llm_as_judge_base" in response.scores
 
     @pytest.mark.asyncio
     async def test_eval_run_eval(self, eval_stack):
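
As a quick illustration of how the judge_score_regexes configured above are meant to be used, here is a minimal standalone sketch. It is not part of this diff, and the helper name extract_judge_score is hypothetical; it only shows how patterns of this form can pull an integer rating out of a judge model's free-form reply.

import re

# Same patterns as in the LLMAsJudgeScoringFnParams above; tried in order.
JUDGE_SCORE_REGEXES = [
    r"Total rating: (\d+)",
    r"rating: (\d+)",
    r"Rating: (\d+)",
]

def extract_judge_score(judge_response: str):
    """Return the first integer rating matched by any pattern, or None."""
    for pattern in JUDGE_SCORE_REGEXES:
        match = re.search(pattern, judge_response)
        if match:
            return int(match.group(1))
    return None

# Example: a reply that follows the JUDGE_PROMPT format yields 4.
assert extract_judge_score("Feedback:::\nTotal rating: 4") == 4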