scoring test pass

Xi Yan 2024-11-06 17:27:55 -08:00
parent 0351072531
commit 0bce74402f
4 changed files with 32 additions and 10 deletions

View file

@@ -28,7 +28,7 @@ llm_as_judge_8b_correctness = ScoringFnDef(
     description="Llm As Judge Scoring Function",
     parameters=[],
     return_type=NumberType(),
-    context=LLMAsJudgeScoringFnParams(
+    params=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
         judge_model="Llama3.1-8B-Instruct",
         judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],

View file

@@ -41,26 +41,26 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             scoring_fn_identifier is not None
         ), "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
-        assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
+        assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
         assert (
-            fn_def.context.prompt_template is not None
+            fn_def.params.prompt_template is not None
         ), "LLM Judge prompt_template not found."
         assert (
-            fn_def.context.judge_score_regex is not None
+            fn_def.params.judge_score_regex is not None
         ), "LLM Judge judge_score_regex not found."
         input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
-        judge_input_msg = fn_def.context.prompt_template.format(
+        judge_input_msg = fn_def.params.prompt_template.format(
             input_query=input_query,
             expected_answer=expected_answer,
             generated_answer=generated_answer,
         )
         judge_response = await self.inference_api.chat_completion(
-            model=fn_def.context.judge_model,
+            model=fn_def.params.judge_model,
             messages=[
                 {
                     "role": "user",
@@ -69,7 +69,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             ],
         )
         content = judge_response.completion_message.content
-        rating_regexs = fn_def.context.judge_score_regex
+        rating_regexs = fn_def.params.judge_score_regex
         judge_rating = None
         for regex in rating_regexs:
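
The hunk is cut off after the regex loop; a hedged sketch of how this kind of extraction typically finishes is below. The loop body is not visible in this diff, so treat it as illustrative only (everything beyond rating_regexs, content, and judge_rating is assumed):

# Sketch of a plausible continuation of the loop shown above; the actual
# body is not part of this diff.
import re

judge_rating = None
for regex in rating_regexs:
    match = re.search(regex, content)
    if match:
        # Take the first regex that matches and use its captured group as the score.
        judge_rating = int(match.group(1))
        break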