Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-11 20:40:40 +00:00)
init commit
This commit is contained in: parent ade76e4a69, commit bde99482c6
2 changed files with 156 additions and 0 deletions
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    LLMAsJudgeScoringFnParams,
    ScoringFn,
)

EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

Examples:

Expression 1: $2x+3$
Expression 2: $3+2x$

Yes

Expression 1: 3/2
Expression 2: 1.5

Yes

Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$

No

Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$

Yes

Expression 1: 3245/5
Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

Expression 1: 2/(-3)
Expression 2: -2/3

Yes
(trivial simplifications are allowed)

Expression 1: 72 degrees
Expression 2: 72

Yes
(give benefit of the doubt to units)

Expression 1: 64
Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()


llm_as_judge_405b_math_match = ScoringFn(
    identifier="llm-as-judge::405b-math-match",
    description="LLM-as-judge scoring function for math benchmarks (adapted from https://github.com/openai/simple-evals/blob/main/math_eval.py).",
    return_type=NumberType(),
    provider_id="llm-as-judge",
    provider_resource_id="llm-as-judge-405b-math-match",
    params=LLMAsJudgeScoringFnParams(
        judge_model="meta-llama/Llama-3.1-405B-Instruct",
        prompt_template=EQUALITY_TEMPLATE,
        aggregation_functions=[AggregationFunctionType.accuracy],
    ),
)
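Note that EQUALITY_TEMPLATE uses %-style placeholders (%(expression1)s, %(expression2)s), so it is meant to be filled with Python's % operator rather than str.format(), which would leave the placeholders untouched. A minimal sketch of the substitution, using illustrative expressions that are not part of this diff:

# Sketch: fill EQUALITY_TEMPLATE's %-style placeholders (values are illustrative).
prompt = EQUALITY_TEMPLATE % {
    "expression1": "3/2",
    "expression2": "1.5",
}
# The judge is expected to reply with a bare "Yes" here, since the two
# expressions are trivially equivalent.
print(prompt.splitlines()[-1])  # -> Expression 2: 1.5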
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional

from llama_stack.apis.inference.inference import Inference, UserMessage
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.llm_as_judge_405b_math_match import llm_as_judge_405b_math_match
from .fn_defs.llm_as_judge_base import llm_as_judge_base


class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
    """
    A scoring_fn that assigns a score to a generated answer by asking an
    LLM judge model to evaluate it against the expected answer.
    """

    def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
        super().__init__(*arg, **kwargs)
        self.inference_api = inference_api
        self.supported_fn_defs_registry = {
            llm_as_judge_base.identifier: llm_as_judge_base,
            llm_as_judge_405b_math_match.identifier: llm_as_judge_405b_math_match,
        }

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        assert scoring_fn_identifier is not None, "Scoring function identifier not found."
        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]

        # override params if scoring_params is provided
        if scoring_params is not None:
            fn_def.params = scoring_params

        assert fn_def.params is not None, f"LLMAsJudgeScoringFnParams not found for {fn_def}."
        assert fn_def.params.prompt_template is not None, "LLM Judge prompt_template not found."

        expected_answer = input_row["expected_answer"]
        generated_answer = input_row["generated_answer"]

        # The equality template uses %-style placeholders (%(expression1)s /
        # %(expression2)s), so substitute with the % operator; str.format()
        # would leave the placeholders unfilled.
        judge_input_msg = fn_def.params.prompt_template % {
            "expression1": expected_answer,
            "expression2": generated_answer,
        }

        judge_response = await self.inference_api.chat_completion(
            model_id=fn_def.params.judge_model,
            messages=[
                UserMessage(
                    content=judge_input_msg,
                ),
            ],
        )

        # chat_completion returns a response object; the judge's verdict is
        # the text content of its completion message.
        content = judge_response.completion_message.content
        score = 1.0 if content.strip().lower() == "yes" else 0.0

        return {
            "score": score,
            "judge_feedback": content,
        }
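As a usage sketch (the inference_api stand-in and the demo wrapper below are illustrative assumptions, not part of this diff), scoring a single row against the math-match judge looks like:

# Usage sketch: score one row with the 405B math-match judge.
# `inference_api` is assumed to be any object implementing the Inference API.
async def demo(inference_api: Inference) -> None:
    scoring_fn = LlmAsJudgeScoringFn(inference_api=inference_api)
    result = await scoring_fn.score_row(
        input_row={"expected_answer": "1.5", "generated_answer": "3/2"},
        scoring_fn_identifier="llm-as-judge::405b-math-match",
    )
    print(result)  # e.g. {"score": 1.0, "judge_feedback": "Yes"}

The AggregationFunctionType.accuracy aggregation declared in the fn def then reduces these per-row scores to their mean, i.e. the fraction of rows the judge marked "Yes".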