commit db32bc28d6
parent 8eff285017
Author: Botao Chen
Date:   2025-03-11 21:59:00 -07:00
4 changed files with 20 additions and 13 deletions

File 1 of 4

@@ -24,8 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
 from .config import LlmAsJudgeScoringConfig
 from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
+from .scoring_fn.llm_as_judge_math_match_fn import LlmAsJudgeMathMatchFn

-LLM_JUDGE_FN = LlmAsJudgeScoringFn
+LLM_JUDGE_FN = [LlmAsJudgeScoringFn, LlmAsJudgeMathMatchFn]


 class LlmAsJudgeScoringImpl(
@@ -45,8 +46,9 @@ class LlmAsJudgeScoringImpl(
         self.inference_api = inference_api

     async def initialize(self) -> None:
-        impl = LLM_JUDGE_FN(inference_api=self.inference_api)
-        self.llm_as_judge_fn = impl
+        for fn in LLM_JUDGE_FN:
+            impl = fn(inference_api=self.inference_api)
+            self.llm_as_judge_fn = impl

     async def shutdown(self) -> None: ...
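For readers skimming this first file: LLM_JUDGE_FN changes from a single class to a list, and initialize() now instantiates every entry against the shared inference API. The standalone sketch below shows that pattern with stub classes; the identifier-keyed registry is an illustrative assumption for how dispatch could work, not code from llama_stack (the committed loop keeps a single self.llm_as_judge_fn reference).

# Illustrative sketch only: stub classes, not the llama_stack implementations.
class _StubInferenceAPI:
    pass


class _JudgeFn:
    supported_ids: list = []  # scoring-function identifiers this class handles

    def __init__(self, inference_api):
        self.inference_api = inference_api


class _GeneralJudgeFn(_JudgeFn):
    supported_ids = ["llm-as-judge::base"]


class _MathMatchJudgeFn(_JudgeFn):
    supported_ids = ["llm-as-judge::405b-math-match"]


LLM_JUDGE_FN = [_GeneralJudgeFn, _MathMatchJudgeFn]  # same list shape as the diff above

inference_api = _StubInferenceAPI()
registry = {}
for fn in LLM_JUDGE_FN:  # instantiate each class once, as initialize() now does
    impl = fn(inference_api=inference_api)
    for fn_id in impl.supported_ids:  # index by identifier so score() could dispatch
        registry[fn_id] = impl

print(sorted(registry))  # ['llm-as-judge::405b-math-match', 'llm-as-judge::base']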

File 2 of 4

@@ -11,7 +11,7 @@ from llama_stack.apis.scoring_functions import (
     ScoringFn,
 )

-EQUALITY_TEMPLATE = r"""
+EQUALITY_TEMPLATE = """
 Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

 Examples:
@@ -67,8 +67,8 @@ YOUR TASK
 Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

-    Expression 1: %(expression1)s
-    Expression 2: %(expression2)s
+    Expression 1: {expression1}
+    Expression 2: {expression2}
 """.strip()

@@ -79,7 +79,7 @@ llm_as_judge_405b_math_match = ScoringFn(
     provider_id="llm-as-judge",
     provider_resource_id="llm-as-judge-405b-math-match",
     params=LLMAsJudgeScoringFnParams(
-        judge_model="meta-llama/Llama-3.1-405B-Instruct",
+        judge_model="openai/gpt-4o",
         prompt_template=EQUALITY_TEMPLATE,
         aggregation_functions=[AggregationFunctionType.accuracy],
     ),
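The template edits above swap %-style placeholders for str.format fields, which matches the .format(expression1=..., expression2=...) call added in the next file. A self-contained illustration of the two mechanisms, using a toy template rather than the full EQUALITY_TEMPLATE:

# Toy stand-in for EQUALITY_TEMPLATE; only the placeholder mechanics matter here.
TOY_TEMPLATE = """
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

    Expression 1: {expression1}
    Expression 2: {expression2}
""".strip()

print(TOY_TEMPLATE.format(expression1="3/2", expression2="1.5"))

# The old %-style fields need the % operator instead; str.format() would have
# passed "%(expression1)s" through unchanged.
OLD_STYLE = "Expression 1: %(expression1)s"
print(OLD_STYLE % {"expression1": "3/2"})

Two caveats follow directly from these edits: with str.format, any literal braces elsewhere in the full template body (for example LaTeX such as \frac{1}{2} in the worked examples, if present) must be doubled as {{ and }} or the call will raise; and dropping the r prefix means recognized escape sequences such as \f, \b, or \t in the template body are now interpreted rather than kept literal.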

File 3 of 4

@@ -12,9 +12,9 @@ from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseSc
 from .fn_defs.llm_as_judge_405b_math_match import llm_as_judge_405b_math_match
 from .fn_defs.llm_as_judge_base import llm_as_judge_base
+from ...basic.utils.math_utils import extract_result_from_boxed


-class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
+class LlmAsJudgeMathMatchFn(RegisteredBaseScoringFn):
     """
     A scoring_fn that assigns
     """
@@ -47,8 +47,8 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
         generated_answer = input_row["generated_answer"]

         judge_input_msg = fn_def.params.prompt_template.format(
-            expected_answer=expected_answer,
-            generated_answer=generated_answer,
+            expression1=expected_answer,
+            expression2=extract_result_from_boxed(generated_answer),
         )

         print("judge_input_msg", judge_input_msg)
@@ -62,9 +62,11 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
             ],
         )

-        score = 1.0 if judge_response.lower().strip() == "yes" else 0.0
+        content = judge_response.completion_message.content
+        score = 1.0 if content.lower().strip() == "yes" else 0.0

         return {
             "score": score,
-            "judge_feedback": judge_response,
+            "judge_feedback": content,
         }
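Two behaviors in this file are easier to see in isolation: the generated answer is reduced to its \boxed{...} contents before being shown to the judge, and only a bare, case-insensitive "Yes" reply counts as a match. The helper below is a hypothetical approximation of extract_result_from_boxed (the real one is imported from the basic provider's math utils and may behave differently); the scoring line mirrors the diff.

import re


def extract_result_from_boxed_sketch(answer: str) -> str:
    """Hypothetical stand-in: return the contents of the last \\boxed{...} span,
    tolerating one level of nested braces; falls back to the raw answer."""
    matches = list(re.finditer(r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}", answer))
    return matches[-1].group(1) if matches else answer.strip()


def score_from_judge_reply(content: str) -> float:
    # Mirrors the updated parsing: anything other than a bare "yes" scores 0.0.
    return 1.0 if content.lower().strip() == "yes" else 0.0


print(extract_result_from_boxed_sketch(r"So the sum is \boxed{\frac{1}{2}}"))  # \frac{1}{2}
print(score_from_judge_reply(" Yes\n"))                                        # 1.0
print(score_from_judge_reply("Yes, they are equivalent."))                     # 0.0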

File 4 of 4

@@ -219,6 +219,9 @@ benchmarks:
 - benchmark_id: meta-reference-math-500
   dataset_id: math_500
   scoring_functions: ["basic::regex_parser_math_response"]
+- benchmark_id: meta-reference-math-500-llm-as-judge
+  dataset_id: math_500
+  scoring_functions: ["llm-as-judge::405b-math-match"]
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
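The run.yaml addition registers a second math_500 benchmark that scores with the judge instead of the regex parser. Below is a hedged sketch of exercising the new scoring function against a running stack; it assumes llama-stack-client is installed, the server is on the default local port, and that client.scoring.score keeps the call shape shown here, so verify all three against your setup.

# Assumption-heavy sketch: a llama-stack server built from this template is running
# locally, and the installed llama-stack-client exposes scoring.score with this
# signature. Adjust base_url and the call shape to match your versions.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

rows = [
    {
        "input_query": "What is 3/2 + 1?",
        "expected_answer": "5/2",
        "generated_answer": r"Adding the fractions gives \boxed{2.5}",
    }
]

response = client.scoring.score(
    input_rows=rows,
    scoring_functions={"llm-as-judge::405b-math-match": None},
)
print(response)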