mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 06:53:47 +00:00
scoring test pass
This commit is contained in:
parent
0351072531
commit
0bce74402f
4 changed files with 32 additions and 10 deletions
|
@ -28,7 +28,7 @@ llm_as_judge_8b_correctness = ScoringFnDef(
|
|||
description="Llm As Judge Scoring Function",
|
||||
parameters=[],
|
||||
return_type=NumberType(),
|
||||
context=LLMAsJudgeScoringFnParams(
|
||||
params=LLMAsJudgeScoringFnParams(
|
||||
prompt_template=JUDGE_PROMPT,
|
||||
judge_model="Llama3.1-8B-Instruct",
|
||||
judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
|
||||
|
|
|
@ -41,26 +41,26 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
|
|||
scoring_fn_identifier is not None
|
||||
), "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
|
||||
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
|
||||
assert (
|
||||
fn_def.context.prompt_template is not None
|
||||
fn_def.params.prompt_template is not None
|
||||
), "LLM Judge prompt_template not found."
|
||||
assert (
|
||||
fn_def.context.judge_score_regex is not None
|
||||
fn_def.params.judge_score_regex is not None
|
||||
), "LLM Judge judge_score_regex not found."
|
||||
|
||||
input_query = input_row["input_query"]
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
judge_input_msg = fn_def.context.prompt_template.format(
|
||||
judge_input_msg = fn_def.params.prompt_template.format(
|
||||
input_query=input_query,
|
||||
expected_answer=expected_answer,
|
||||
generated_answer=generated_answer,
|
||||
)
|
||||
|
||||
judge_response = await self.inference_api.chat_completion(
|
||||
model=fn_def.context.judge_model,
|
||||
model=fn_def.params.judge_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
@ -69,7 +69,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
|
|||
],
|
||||
)
|
||||
content = judge_response.completion_message.content
|
||||
rating_regexs = fn_def.context.judge_score_regex
|
||||
rating_regexs = fn_def.params.judge_score_regex
|
||||
|
||||
judge_rating = None
|
||||
for regex in rating_regexs:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue