diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_math_match.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_math_match.py
new file mode 100644
index 000000000..34746b3e8
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_math_match.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    LLMAsJudgeScoringFnParams,
+    ScoringFn,
+)
+
+EQUALITY_TEMPLATE = r"""
+Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
+
+Examples:
+
+    Expression 1: $2x+3$
+    Expression 2: $3+2x$
+
+Yes
+
+    Expression 1: 3/2
+    Expression 2: 1.5
+
+Yes
+
+    Expression 1: $x^2+2x+1$
+    Expression 2: $y^2+2y+1$
+
+No
+
+    Expression 1: $x^2+2x+1$
+    Expression 2: $(x+1)^2$
+
+Yes
+
+    Expression 1: 3245/5
+    Expression 2: 649
+
+No
+(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
+
+    Expression 1: 2/(-3)
+    Expression 2: -2/3
+
+Yes
+(trivial simplifications are allowed)
+
+    Expression 1: 72 degrees
+    Expression 2: 72
+
+Yes
+(give benefit of the doubt to units)
+
+    Expression 1: 64
+    Expression 2: 64 square feet
+
+Yes
+(give benefit of the doubt to units)
+
+---
+
+YOUR TASK
+
+
+Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
+
+    Expression 1: {expected_answer}
+    Expression 2: {generated_answer}
+""".strip()
+
+
+llm_as_judge_405b_math_match = ScoringFn(
+    identifier="llm-as-judge::405b-math-match",
+    description="LLM-as-judge scoring function for math benchmarks (see https://github.com/openai/simple-evals/blob/main/math_eval.py)",
+    return_type=NumberType(),
+    provider_id="llm-as-judge",
+    provider_resource_id="llm-as-judge-405b-math-match",
+    params=LLMAsJudgeScoringFnParams(
+        judge_model="meta-llama/Llama-3.1-405B-Instruct",
+        prompt_template=EQUALITY_TEMPLATE,
+        aggregation_functions=[AggregationFunctionType.accuracy],
+    ),
+)
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_math_match_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_math_match_fn.py
new file mode 100644
index 000000000..67c0a543f
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_math_match_fn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+from llama_stack.apis.inference.inference import Inference, UserMessage
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import ScoringFnParams
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+
+from .fn_defs.llm_as_judge_405b_math_match import llm_as_judge_405b_math_match
+from .fn_defs.llm_as_judge_base import llm_as_judge_base
+
+
+class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
+    """
+    A scoring_fn that uses an LLM judge to check whether a generated answer matches the expected answer.
+    """
+
+    def __init__(self, inference_api: Inference, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.inference_api = inference_api
+        self.supported_fn_defs_registry = {
+            llm_as_judge_base.identifier: llm_as_judge_base,
+            llm_as_judge_405b_math_match.identifier: llm_as_judge_405b_math_match,
+        }
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> ScoringResultRow:
+        assert scoring_fn_identifier is not None, "Scoring function identifier not found."
+        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+
+        # override params if scoring_params is provided
+        if scoring_params is not None:
+            fn_def.params = scoring_params
+
+        assert fn_def.params is not None, f"LLMAsJudgeScoringFnParams not found for {fn_def}."
+        assert fn_def.params.prompt_template is not None, "LLM Judge prompt_template not found."
+
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+
+        # fill the judge prompt with the two answers to compare
+        judge_input_msg = fn_def.params.prompt_template.format(
+            expected_answer=expected_answer,
+            generated_answer=generated_answer,
+        )
+
+        judge_response = await self.inference_api.chat_completion(
+            model_id=fn_def.params.judge_model,
+            messages=[
+                UserMessage(
+                    content=judge_input_msg,
+                ),
+            ],
+        )
+        content = judge_response.completion_message.content
+
+        # the judge is prompted to reply with exactly "Yes" or "No"
+        score = 1.0 if content.strip().lower() == "yes" else 0.0
+
+        return {
+            "score": score,
+            "judge_feedback": content,
+        }
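
For context, below is a minimal usage sketch of the new llm-as-judge::405b-math-match scoring function through the llama-stack client. It is not part of the diff: the base URL, the sample row, and the exact client call shapes are assumptions, and the judge model must be served by an inference provider in the running distribution.

# Hypothetical usage sketch (not part of this change). Assumes a llama-stack
# distribution is running locally on the default port and that it can route
# inference requests to meta-llama/Llama-3.1-405B-Instruct.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Each row needs "expected_answer" and "generated_answer", the keys read by
# LlmAsJudgeScoringFn.score_row().
rows = [
    {
        "expected_answer": "3/2",
        "generated_answer": "1.5",
    }
]

# Passing None uses the parameters baked into the scoring function definition
# (judge model, prompt template, accuracy aggregation).
response = client.scoring.score(
    input_rows=rows,
    scoring_functions={"llm-as-judge::405b-math-match": None},
)
print(response.results["llm-as-judge::405b-math-match"].score_rows)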