Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-30 07:39:38 +00:00
scoring test pass

commit 0bce74402f (parent 0351072531)
4 changed files with 32 additions and 10 deletions
@@ -28,7 +28,7 @@ llm_as_judge_8b_correctness = ScoringFnDef(
     description="Llm As Judge Scoring Function",
     parameters=[],
     return_type=NumberType(),
-    context=LLMAsJudgeScoringFnParams(
+    params=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
         judge_model="Llama3.1-8B-Instruct",
         judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
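The only functional change in this hunk is the keyword rename from context= to params=; the three fields passed to LLMAsJudgeScoringFnParams are unchanged. As a rough sketch of what that parameters model plausibly looks like (field names come from the call above; the pydantic base class, types, and defaults are assumptions, not taken from this commit):

from typing import List, Optional

from pydantic import BaseModel


class LLMAsJudgeScoringFnParams(BaseModel):
    # Field names mirror the keyword arguments used in the diff above;
    # types and defaults are guesses for illustration only.
    judge_model: str
    prompt_template: Optional[str] = None
    judge_score_regex: Optional[List[str]] = None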
@@ -41,26 +41,26 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             scoring_fn_identifier is not None
         ), "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
-        assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
+        assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
         assert (
-            fn_def.context.prompt_template is not None
+            fn_def.params.prompt_template is not None
         ), "LLM Judge prompt_template not found."
         assert (
-            fn_def.context.judge_score_regex is not None
+            fn_def.params.judge_score_regex is not None
         ), "LLM Judge judge_score_regex not found."

         input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]

-        judge_input_msg = fn_def.context.prompt_template.format(
+        judge_input_msg = fn_def.params.prompt_template.format(
             input_query=input_query,
             expected_answer=expected_answer,
             generated_answer=generated_answer,
         )

         judge_response = await self.inference_api.chat_completion(
-            model=fn_def.context.judge_model,
+            model=fn_def.params.judge_model,
             messages=[
                 {
                     "role": "user",
@@ -69,7 +69,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             ],
         )
         content = judge_response.completion_message.content
-        rating_regexs = fn_def.context.judge_score_regex
+        rating_regexs = fn_def.params.judge_score_regex

         judge_rating = None
         for regex in rating_regexs:
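The hunk above ends just as the loop over rating_regexs begins, so the extraction logic itself is not shown. A plausible continuation, assuming the standard re module and that the first capturing group of each pattern holds the numeric rating:

import re

judge_rating = None
for regex in rating_regexs:
    # Try each pattern against the judge's reply; keep the first numeric match.
    match = re.search(regex, content)
    if match:
        judge_rating = int(match.group(1))
        break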
@ -37,7 +37,6 @@ SCORING_FIXTURES = ["meta_reference", "remote"]
|
|||
@pytest_asyncio.fixture(scope="session")
|
||||
async def scoring_stack(request):
|
||||
fixture_dict = request.param
|
||||
print("!!!", fixture_dict)
|
||||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
|
@@ -56,5 +55,6 @@ async def scoring_stack(request):
     return (
         impls[Api.scoring],
         impls[Api.scoring_functions],
+        impls[Api.datasetio],
         impls[Api.datasets],
     )
@@ -21,14 +21,36 @@ class TestScoring:
     async def test_scoring_functions_list(self, scoring_stack):
         # NOTE: this needs you to ensure that you are starting from a clean state
         # but so far we don't have an unregister API unfortunately, so be careful
-        _, scoring_functions_impl, _ = scoring_stack
+        _, scoring_functions_impl, _, _ = scoring_stack
         response = await scoring_functions_impl.list_scoring_functions()
         assert isinstance(response, list)
         assert len(response) > 0

     @pytest.mark.asyncio
     async def test_scoring_score(self, scoring_stack):
-        scoring_impl, scoring_functions_impl, datasets_impl = scoring_stack
+        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
+            scoring_stack
+        )
         await register_dataset(datasets_impl)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
+
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        scoring_functions = [
+            "meta-reference::llm_as_judge_8b_correctness",
+            "meta-reference::equality",
+        ]
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
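Based only on the assertions in this test, response.results behaves like a mapping from scoring-function identifier to a result object whose score_rows list holds one entry per scored input row. A minimal consumption sketch under that assumption (nothing beyond .results and .score_rows is taken from the commit):

# Hypothetical inspection of the object returned by scoring_impl.score().
for fn_id, result in response.results.items():
    print(f"{fn_id}: {len(result.score_rows)} scored rows")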