tests w/ eval params

Xi Yan 2024-11-11 16:01:21 -05:00
parent f8f95dad1f
commit ca2cd71182
2 changed files with 35 additions and 1 deletion

@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+JUDGE_PROMPT = """
+You will be given a question, an expected_answer, and a system_answer.
+Your task is to provide a 'total rating' scoring how well the system_answer matches the ground truth in expected_answer, in terms of factual correctness with respect to the question.
+Give your answer as an integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with the expected_answer, and 5 means that the answer completely and correctly answers the question.
+Provide your feedback as follows:
+Feedback:::
+Total rating: (your rating, as an int between 0 and 5)
+Now here are the question, expected_answer, and system_answer.
+Question: {input_query}
+Expected Answer: {expected_answer}
+System Answer: {generated_answer}
+Feedback:::
+Total rating:
+"""

@@ -19,9 +19,10 @@ from llama_stack.apis.eval.eval import (
     EvalTaskDefWithProvider,
     ModelCandidate,
 )
+from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
+from .constants import JUDGE_PROMPT

 # How to run this test:
 #
@@ -65,6 +66,7 @@ class Testeval:
         assert len(rows.rows) == 3
         scoring_functions = [
+            "meta-reference::llm_as_judge_base",
             "meta-reference::equality",
         ]
         task_id = "meta-reference::app_eval"
@@ -84,10 +86,22 @@ class Testeval:
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),
                 ),
+                scoring_params={
+                    "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
+                        judge_model="Llama3.1-8B-Instruct",
+                        prompt_template=JUDGE_PROMPT,
+                        judge_score_regexes=[
+                            r"Total rating: (\d+)",
+                            r"rating: (\d+)",
+                            r"Rating: (\d+)",
+                        ],
+                    )
+                },
             ),
         )
         assert len(response.generations) == 3
         assert "meta-reference::equality" in response.scores
+        assert "meta-reference::llm_as_judge_base" in response.scores

     @pytest.mark.asyncio
     async def test_eval_run_eval(self, eval_stack):
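The regex list passed as judge_score_regexes suggests the patterns are tried until one matches. The actual parsing lives in the meta-reference scoring provider, so the helper below is only a sketch under that assumption; the function name extract_judge_score and the sample judge output are hypothetical.

```python
import re
from typing import List, Optional


def extract_judge_score(judge_answer: str, judge_score_regexes: List[str]) -> Optional[int]:
    """Try each pattern in order; return the first captured integer, or None if nothing matches."""
    for pattern in judge_score_regexes:
        match = re.search(pattern, judge_answer)
        if match:
            return int(match.group(1))
    return None


# Usage with the patterns from the test above (the judge output is a made-up sample).
score = extract_judge_score(
    "Feedback::: the system answer matches the expected answer.\nRating: 5",
    [r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
)
assert score == 5
```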