mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-11 20:40:40 +00:00
init commit
This commit is contained in:
parent
ade76e4a69
commit
bde99482c6
2 changed files with 156 additions and 0 deletions
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
LLMAsJudgeScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
EQUALITY_TEMPLATE = r"""
|
||||
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
|
||||
|
||||
Examples:
|
||||
|
||||
Expression 1: $2x+3$
|
||||
Expression 2: $3+2x$
|
||||
|
||||
Yes
|
||||
|
||||
Expression 1: 3/2
|
||||
Expression 2: 1.5
|
||||
|
||||
Yes
|
||||
|
||||
Expression 1: $x^2+2x+1$
|
||||
Expression 2: $y^2+2y+1$
|
||||
|
||||
No
|
||||
|
||||
Expression 1: $x^2+2x+1$
|
||||
Expression 2: $(x+1)^2$
|
||||
|
||||
Yes
|
||||
|
||||
Expression 1: 3245/5
|
||||
Expression 2: 649
|
||||
|
||||
No
|
||||
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
|
||||
|
||||
Expression 1: 2/(-3)
|
||||
Expression 2: -2/3
|
||||
|
||||
Yes
|
||||
(trivial simplifications are allowed)
|
||||
|
||||
Expression 1: 72 degrees
|
||||
Expression 2: 72
|
||||
|
||||
Yes
|
||||
(give benefit of the doubt to units)
|
||||
|
||||
Expression 1: 64
|
||||
Expression 2: 64 square feet
|
||||
|
||||
Yes
|
||||
(give benefit of the doubt to units)
|
||||
|
||||
---
|
||||
|
||||
YOUR TASK
|
||||
|
||||
|
||||
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
|
||||
|
||||
Expression 1: %(expression1)s
|
||||
Expression 2: %(expression2)s
|
||||
""".strip()
|
||||
|
||||
|
||||
llm_as_judge_405b_math_match = ScoringFn(
|
||||
identifier="llm-as-judge::405b-math-match",
|
||||
description="Llm As Judge Scoring Function for Math Related Benchmark https://github.com/openai/simple-evals/blob/main/math_eval.py)",
|
||||
return_type=NumberType(),
|
||||
provider_id="llm-as-judge",
|
||||
provider_resource_id="llm-as-judge-405b-math-match",
|
||||
params=LLMAsJudgeScoringFnParams(
|
||||
judge_model="meta-llama/Llama-3.1-405B-Instruct",
|
||||
prompt_template=EQUALITY_TEMPLATE,
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.inference.inference import Inference, UserMessage
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.llm_as_judge_405b_math_match import llm_as_judge_405b_math_match
|
||||
from .fn_defs.llm_as_judge_base import llm_as_judge_base
|
||||
|
||||
|
||||
class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that assigns
|
||||
"""
|
||||
|
||||
def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
|
||||
super().__init__(*arg, **kwargs)
|
||||
self.inference_api = inference_api
|
||||
self.supported_fn_defs_registry = {
|
||||
llm_as_judge_base.identifier: llm_as_judge_base,
|
||||
llm_as_judge_405b_math_match.identifier: llm_as_judge_405b_math_match,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
|
||||
# override params if scoring_params is provided
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
|
||||
assert fn_def.params.prompt_template is not None, "LLM Judge prompt_template not found."
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
judge_input_msg = fn_def.params.prompt_template.format(
|
||||
expected_answer=expected_answer,
|
||||
generated_answer=generated_answer,
|
||||
)
|
||||
|
||||
print("judge_input_msg", judge_input_msg)
|
||||
|
||||
judge_response = await self.inference_api.chat_completion(
|
||||
model_id=fn_def.params.judge_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content=judge_input_msg,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
score = 1.0 if judge_response.lower().strip() == "yes" else 0.0
|
||||
|
||||
return {
|
||||
"score": score,
|
||||
"judge_feedback": judge_response,
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue