init commit

This commit is contained in:
Botao Chen 2025-03-10 20:24:04 -07:00
parent ade76e4a69
commit bde99482c6
2 changed files with 156 additions and 0 deletions

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
LLMAsJudgeScoringFnParams,
ScoringFn,
)
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
Examples:
Expression 1: $2x+3$
Expression 2: $3+2x$
Yes
Expression 1: 3/2
Expression 2: 1.5
Yes
Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$
No
Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$
Yes
Expression 1: 3245/5
Expression 2: 649
No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
Expression 1: 2/(-3)
Expression 2: -2/3
Yes
(trivial simplifications are allowed)
Expression 1: 72 degrees
Expression 2: 72
Yes
(give benefit of the doubt to units)
Expression 1: 64
Expression 2: 64 square feet
Yes
(give benefit of the doubt to units)
---
YOUR TASK
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()
llm_as_judge_405b_math_match = ScoringFn(
identifier="llm-as-judge::405b-math-match",
description="Llm As Judge Scoring Function for Math Related Benchmark https://github.com/openai/simple-evals/blob/main/math_eval.py)",
return_type=NumberType(),
provider_id="llm-as-judge",
provider_resource_id="llm-as-judge-405b-math-match",
params=LLMAsJudgeScoringFnParams(
judge_model="meta-llama/Llama-3.1-405B-Instruct",
prompt_template=EQUALITY_TEMPLATE,
aggregation_functions=[AggregationFunctionType.accuracy],
),
)

View file

@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.inference.inference import Inference, UserMessage
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.llm_as_judge_405b_math_match import llm_as_judge_405b_math_match
from .fn_defs.llm_as_judge_base import llm_as_judge_base
class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns
"""
def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
super().__init__(*arg, **kwargs)
self.inference_api = inference_api
self.supported_fn_defs_registry = {
llm_as_judge_base.identifier: llm_as_judge_base,
llm_as_judge_405b_math_match.identifier: llm_as_judge_405b_math_match,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
# override params if scoring_params is provided
if scoring_params is not None:
fn_def.params = scoring_params
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
assert fn_def.params.prompt_template is not None, "LLM Judge prompt_template not found."
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
judge_input_msg = fn_def.params.prompt_template.format(
expected_answer=expected_answer,
generated_answer=generated_answer,
)
print("judge_input_msg", judge_input_msg)
judge_response = await self.inference_api.chat_completion(
model_id=fn_def.params.judge_model,
messages=[
UserMessage(
content=judge_input_msg,
),
],
)
score = 1.0 if judge_response.lower().strip() == "yes" else 0.0
return {
"score": score,
"judge_feedback": judge_response,
}