mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
scoring test pass
This commit is contained in:
parent
0351072531
commit
0bce74402f
4 changed files with 32 additions and 10 deletions
|
@ -28,7 +28,7 @@ llm_as_judge_8b_correctness = ScoringFnDef(
|
||||||
description="Llm As Judge Scoring Function",
|
description="Llm As Judge Scoring Function",
|
||||||
parameters=[],
|
parameters=[],
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
context=LLMAsJudgeScoringFnParams(
|
params=LLMAsJudgeScoringFnParams(
|
||||||
prompt_template=JUDGE_PROMPT,
|
prompt_template=JUDGE_PROMPT,
|
||||||
judge_model="Llama3.1-8B-Instruct",
|
judge_model="Llama3.1-8B-Instruct",
|
||||||
judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
|
judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
|
||||||
|
|
|
@ -41,26 +41,26 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
|
||||||
scoring_fn_identifier is not None
|
scoring_fn_identifier is not None
|
||||||
), "Scoring function identifier not found."
|
), "Scoring function identifier not found."
|
||||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||||
assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
|
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
|
||||||
assert (
|
assert (
|
||||||
fn_def.context.prompt_template is not None
|
fn_def.params.prompt_template is not None
|
||||||
), "LLM Judge prompt_template not found."
|
), "LLM Judge prompt_template not found."
|
||||||
assert (
|
assert (
|
||||||
fn_def.context.judge_score_regex is not None
|
fn_def.params.judge_score_regex is not None
|
||||||
), "LLM Judge judge_score_regex not found."
|
), "LLM Judge judge_score_regex not found."
|
||||||
|
|
||||||
input_query = input_row["input_query"]
|
input_query = input_row["input_query"]
|
||||||
expected_answer = input_row["expected_answer"]
|
expected_answer = input_row["expected_answer"]
|
||||||
generated_answer = input_row["generated_answer"]
|
generated_answer = input_row["generated_answer"]
|
||||||
|
|
||||||
judge_input_msg = fn_def.context.prompt_template.format(
|
judge_input_msg = fn_def.params.prompt_template.format(
|
||||||
input_query=input_query,
|
input_query=input_query,
|
||||||
expected_answer=expected_answer,
|
expected_answer=expected_answer,
|
||||||
generated_answer=generated_answer,
|
generated_answer=generated_answer,
|
||||||
)
|
)
|
||||||
|
|
||||||
judge_response = await self.inference_api.chat_completion(
|
judge_response = await self.inference_api.chat_completion(
|
||||||
model=fn_def.context.judge_model,
|
model=fn_def.params.judge_model,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@ -69,7 +69,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
content = judge_response.completion_message.content
|
content = judge_response.completion_message.content
|
||||||
rating_regexs = fn_def.context.judge_score_regex
|
rating_regexs = fn_def.params.judge_score_regex
|
||||||
|
|
||||||
judge_rating = None
|
judge_rating = None
|
||||||
for regex in rating_regexs:
|
for regex in rating_regexs:
|
||||||
|
|
|
@ -37,7 +37,6 @@ SCORING_FIXTURES = ["meta_reference", "remote"]
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def scoring_stack(request):
|
async def scoring_stack(request):
|
||||||
fixture_dict = request.param
|
fixture_dict = request.param
|
||||||
print("!!!", fixture_dict)
|
|
||||||
|
|
||||||
providers = {}
|
providers = {}
|
||||||
provider_data = {}
|
provider_data = {}
|
||||||
|
@ -56,5 +55,6 @@ async def scoring_stack(request):
|
||||||
return (
|
return (
|
||||||
impls[Api.scoring],
|
impls[Api.scoring],
|
||||||
impls[Api.scoring_functions],
|
impls[Api.scoring_functions],
|
||||||
|
impls[Api.datasetio],
|
||||||
impls[Api.datasets],
|
impls[Api.datasets],
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,14 +21,36 @@ class TestScoring:
|
||||||
async def test_scoring_functions_list(self, scoring_stack):
|
async def test_scoring_functions_list(self, scoring_stack):
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
# but so far we don't have an unregister API unfortunately, so be careful
|
||||||
_, scoring_functions_impl, _ = scoring_stack
|
_, scoring_functions_impl, _, _ = scoring_stack
|
||||||
response = await scoring_functions_impl.list_scoring_functions()
|
response = await scoring_functions_impl.list_scoring_functions()
|
||||||
assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_scoring_score(self, scoring_stack):
|
async def test_scoring_score(self, scoring_stack):
|
||||||
scoring_impl, scoring_functions_impl, datasets_impl = scoring_stack
|
scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
|
||||||
|
scoring_stack
|
||||||
|
)
|
||||||
await register_dataset(datasets_impl)
|
await register_dataset(datasets_impl)
|
||||||
response = await datasets_impl.list_datasets()
|
response = await datasets_impl.list_datasets()
|
||||||
assert len(response) == 1
|
assert len(response) == 1
|
||||||
|
|
||||||
|
# scoring individual rows
|
||||||
|
rows = await datasetio_impl.get_rows_paginated(
|
||||||
|
dataset_id="test_dataset",
|
||||||
|
rows_in_page=3,
|
||||||
|
)
|
||||||
|
assert len(rows.rows) == 3
|
||||||
|
|
||||||
|
scoring_functions = [
|
||||||
|
"meta-reference::llm_as_judge_8b_correctness",
|
||||||
|
"meta-reference::equality",
|
||||||
|
]
|
||||||
|
response = await scoring_impl.score(
|
||||||
|
input_rows=rows.rows,
|
||||||
|
scoring_functions=scoring_functions,
|
||||||
|
)
|
||||||
|
assert len(response.results) == len(scoring_functions)
|
||||||
|
for x in scoring_functions:
|
||||||
|
assert x in response.results
|
||||||
|
assert len(response.results[x].score_rows) == len(rows.rows)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue