From 9b410a87bfefea0c3fc84311f8e497bce50d594e Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Fri, 25 Oct 2024 17:03:01 -0700
Subject: [PATCH] extract score regex to llm context

---
 .../apis/scoring_functions/scoring_functions.py   |  4 ++++
 .../fn_defs/llm_as_judge_8b_correctness.json      |  3 ++-
 .../scoring/scoring_fn/llm_as_judge_scoring_fn.py | 13 +++++++------
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index fc3584f90..2e5bf0aef 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -26,6 +26,10 @@ class Parameter(BaseModel):
 class LLMAsJudgeContext(BaseModel):
     judge_model: str
     prompt_template: Optional[str] = None
+    judge_score_regex: Optional[List[str]] = Field(
+        description="Regex to extract the score from the judge response",
+        default=None,
+    )
 
 
 @json_schema_type
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json
index 64d86a7ea..e33bc09ee 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json
@@ -8,6 +8,7 @@
     },
     "context": {
         "judge_model": "Llama3.1-8B-Instruct",
-        "prompt_template": "\nYou will be given a question, a expected_answer, and a system_answer.\nYour task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.\nGive your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.\nProvide your feedback as follows:\nFeedback:::\nTotal rating: (your rating, as a int between 0 and 5)\nNow here are the question, expected_answer, system_answer.\nQuestion: {input_query}\nExpected Answer: {expected_answer}\nSystem Answer: {generated_answer}\nFeedback:::\nTotal rating:\n"
+        "prompt_template": "\nYou will be given a question, a expected_answer, and a system_answer.\nYour task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.\nGive your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.\nProvide your feedback as follows:\nFeedback:::\nTotal rating: (your rating, as a int between 0 and 5)\nNow here are the question, expected_answer, system_answer.\nQuestion: {input_query}\nExpected Answer: {expected_answer}\nSystem Answer: {generated_answer}\nFeedback:::\nTotal rating:\n",
+        "judge_score_regex": ["Total rating: (\\d+)", "rating: (\\d+)", "Rating: (\\d+)"]
     }
 }
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
index 16672434f..bf3b8de17 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
@@ -37,8 +37,12 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             scoring_fn_identifier is not None
         ), "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+        assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
         assert (
-            fn_def.context is not None and fn_def.context.prompt_template is not None
+            fn_def.context.prompt_template is not None
+        ), "LLM Judge prompt_template not found."
+        assert (
+            fn_def.context.judge_score_regex is not None
         ), "LLM Judge prompt_template not found."
 
         input_query = input_row["input_query"]
@@ -61,11 +65,8 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             ],
         )
         content = judge_response.completion_message.content
-        rating_regexs = [
-            r"Total rating: (\d+)",
-            r"rating: (\d+)",
-            r"Rating: (\d+)",
-        ]
+        rating_regexs = fn_def.context.judge_score_regex
+
         judge_rating = None
        for regex in rating_regexs:
            match = re.search(regex, content)
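Note: the snippet below is an illustrative sketch, not part of this patch, showing how the judge_score_regex patterns configured in llm_as_judge_8b_correctness.json could be applied to a judge completion. The helper name extract_judge_rating and the int() conversion are assumptions made for the example; the patch itself keeps the equivalent matching loop inline in LlmAsJudgeScoringFn.

import re
from typing import List, Optional

def extract_judge_rating(content: str, judge_score_regex: List[str]) -> Optional[int]:
    # Hypothetical helper: try each configured pattern in order and return the
    # first capture group of the first match as an integer rating.
    for pattern in judge_score_regex:
        match = re.search(pattern, content)
        if match:
            return int(match.group(1))
    return None

# Patterns as configured in llm_as_judge_8b_correctness.json:
patterns = [r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"]
print(extract_judge_rating("Feedback:::\nTotal rating: 4", patterns))  # 4
print(extract_judge_rating("no score in this response", patterns))    # None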