From db32bc28d67d1311d84e91546a0aee8538778816 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Tue, 11 Mar 2025 21:59:00 -0700
Subject: [PATCH] refine

---
 .../inline/scoring/llm_as_judge/scoring.py    |  8 +++++---
 .../fn_defs/llm_as_judge_405b_math_match.py   |  8 ++++----
 .../scoring_fn/llm_as_judge_math_match_fn.py  | 14 ++++++++------
 llama_stack/templates/open-benchmark/run.yaml |  3 +++
 4 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
index 5b1715d9f..ffc178ca8 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
@@ -24,8 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
 
 from .config import LlmAsJudgeScoringConfig
 from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
+from .scoring_fn.llm_as_judge_math_match_fn import LlmAsJudgeMathMatchFn
 
-LLM_JUDGE_FN = LlmAsJudgeScoringFn
+LLM_JUDGE_FN = [LlmAsJudgeScoringFn, LlmAsJudgeMathMatchFn]
 
 
 class LlmAsJudgeScoringImpl(
@@ -45,8 +46,9 @@ class LlmAsJudgeScoringImpl(
         self.inference_api = inference_api
 
     async def initialize(self) -> None:
-        impl = LLM_JUDGE_FN(inference_api=self.inference_api)
-        self.llm_as_judge_fn = impl
+        for fn in LLM_JUDGE_FN:
+            impl = fn(inference_api=self.inference_api)
+            self.llm_as_judge_fn = impl
 
     async def shutdown(self) -> None: ...
 
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_math_match.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_math_match.py
index 34746b3e8..fe0b463f7 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_math_match.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_math_match.py
@@ -11,7 +11,7 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )
 
-EQUALITY_TEMPLATE = r"""
+EQUALITY_TEMPLATE = """
 Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
 
 Examples:
@@ -67,8 +67,8 @@ YOUR TASK
 
 Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
 
-    Expression 1: %(expression1)s
-    Expression 2: %(expression2)s
+    Expression 1: {expression1}
+    Expression 2: {expression2}
 """.strip()
 
 
@@ -79,7 +79,7 @@ llm_as_judge_405b_math_match = ScoringFn(
     provider_id="llm-as-judge",
     provider_resource_id="llm-as-judge-405b-math-match",
     params=LLMAsJudgeScoringFnParams(
-        judge_model="meta-llama/Llama-3.1-405B-Instruct",
+        judge_model="openai/gpt-4o",
         prompt_template=EQUALITY_TEMPLATE,
         aggregation_functions=[AggregationFunctionType.accuracy],
     ),
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_math_match_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_math_match_fn.py
index 67c0a543f..286738b96 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_math_match_fn.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_math_match_fn.py
@@ -12,9 +12,9 @@ from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseSc
 
 from .fn_defs.llm_as_judge_405b_math_match import llm_as_judge_405b_math_match
 from .fn_defs.llm_as_judge_base import llm_as_judge_base
+from ...basic.utils.math_utils import extract_result_from_boxed
 
-
-class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
+class LlmAsJudgeMathMatchFn(RegisteredBaseScoringFn):
     """
     A scoring_fn that assigns
     """
@@ -47,8 +47,8 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
         generated_answer = input_row["generated_answer"]
 
         judge_input_msg = fn_def.params.prompt_template.format(
-            expected_answer=expected_answer,
-            generated_answer=generated_answer,
+            expression1=expected_answer,
+            expression2=extract_result_from_boxed(generated_answer),
         )
 
         print("judge_input_msg", judge_input_msg)
@@ -62,9 +62,11 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
             ],
         )
 
-        score = 1.0 if judge_response.lower().strip() == "yes" else 0.0
+        content = judge_response.completion_message.content
+
+        score = 1.0 if content.lower().strip() == "yes" else 0.0
 
         return {
             "score": score,
-            "judge_feedback": judge_response,
+            "judge_feedback": content,
         }
diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml
index 736b47746..28a5aefd3 100644
--- a/llama_stack/templates/open-benchmark/run.yaml
+++ b/llama_stack/templates/open-benchmark/run.yaml
@@ -219,6 +219,9 @@ benchmarks:
 - benchmark_id: meta-reference-math-500
   dataset_id: math_500
   scoring_functions: ["basic::regex_parser_math_response"]
+- benchmark_id: meta-reference-math-500-llm-as-judge
+  dataset_id: math_500
+  scoring_functions: ["llm-as-judge::405b-math-match"]
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
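
Note (illustration, not part of the patch): a minimal sketch of the scoring flow that the new LlmAsJudgeMathMatchFn implements. The \boxed{} payload is pulled out of the generated answer, both expressions are substituted into the equality template via str.format (which is why the placeholders change from %(expression1)s to {expression1}), and the judge's bare "Yes"/"No" reply is mapped to a 1.0/0.0 score. The _extract_boxed helper and the abridged template below are simplified stand-ins for the repository's extract_result_from_boxed and EQUALITY_TEMPLATE; the judge call is replaced by a hard-coded reply.

import re

# Abridged stand-in for EQUALITY_TEMPLATE (the real template also contains
# worked examples and the full judging instructions).
EQUALITY_TEMPLATE = """
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent.

    Expression 1: {expression1}
    Expression 2: {expression2}
""".strip()


def _extract_boxed(generated_answer: str) -> str:
    # Simplified stand-in for extract_result_from_boxed: return the contents of
    # the last \boxed{...} if present, otherwise the answer unchanged.
    matches = re.findall(r"\\boxed\{([^{}]+)\}", generated_answer)
    return matches[-1] if matches else generated_answer


def score_row(expected_answer: str, generated_answer: str, judge_reply: str) -> dict:
    # Build the judge prompt the same way the scoring fn does (str.format with
    # expression1/expression2), then map a bare "Yes"/"No" reply to 1.0/0.0.
    judge_input_msg = EQUALITY_TEMPLATE.format(
        expression1=expected_answer,
        expression2=_extract_boxed(generated_answer),
    )
    score = 1.0 if judge_reply.lower().strip() == "yes" else 0.0
    return {"score": score, "judge_feedback": judge_reply, "judge_input": judge_input_msg}


if __name__ == "__main__":
    # A real run would send judge_input to the configured judge model; the reply
    # is hard-coded here to show the Yes -> 1.0 mapping.
    result = score_row(
        expected_answer="1/2",
        generated_answer=r"... so the final answer is \boxed{1/2}.",
        judge_reply="Yes",
    )
    print(result["score"], result["judge_feedback"])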