diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
index 60ce74477..09c7badd5 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
@@ -28,7 +28,7 @@ llm_as_judge_8b_correctness = ScoringFnDef(
     description="Llm As Judge Scoring Function",
     parameters=[],
     return_type=NumberType(),
-    context=LLMAsJudgeScoringFnParams(
+    params=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
         judge_model="Llama3.1-8B-Instruct",
         judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
index 84dd28fd7..3c594eac8 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
@@ -41,26 +41,26 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             scoring_fn_identifier is not None
         ), "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
-        assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
+        assert fn_def.params is not None, f"LLMAsJudgeParams not found for {fn_def}."
         assert (
-            fn_def.context.prompt_template is not None
+            fn_def.params.prompt_template is not None
         ), "LLM Judge prompt_template not found."
         assert (
-            fn_def.context.judge_score_regex is not None
+            fn_def.params.judge_score_regex is not None
         ), "LLM Judge judge_score_regex not found."
 
         input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
 
-        judge_input_msg = fn_def.context.prompt_template.format(
+        judge_input_msg = fn_def.params.prompt_template.format(
             input_query=input_query,
             expected_answer=expected_answer,
             generated_answer=generated_answer,
         )
 
         judge_response = await self.inference_api.chat_completion(
-            model=fn_def.context.judge_model,
+            model=fn_def.params.judge_model,
             messages=[
                 {
                     "role": "user",
@@ -69,7 +69,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             ],
         )
         content = judge_response.completion_message.content
-        rating_regexs = fn_def.context.judge_score_regex
+        rating_regexs = fn_def.params.judge_score_regex
 
         judge_rating = None
         for regex in rating_regexs:
diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py
index 337838d8e..925f98779 100644
--- a/llama_stack/providers/tests/scoring/fixtures.py
+++ b/llama_stack/providers/tests/scoring/fixtures.py
@@ -37,7 +37,6 @@ SCORING_FIXTURES = ["meta_reference", "remote"]
 @pytest_asyncio.fixture(scope="session")
 async def scoring_stack(request):
     fixture_dict = request.param
-    print("!!!", fixture_dict)
 
     providers = {}
     provider_data = {}
@@ -56,5 +55,6 @@ async def scoring_stack(request):
     return (
         impls[Api.scoring],
         impls[Api.scoring_functions],
+        impls[Api.datasetio],
         impls[Api.datasets],
     )
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index d1518d2e3..503ccedb3 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -21,14 +21,36 @@ class TestScoring:
     async def test_scoring_functions_list(self, scoring_stack):
         # NOTE: this needs you to ensure that you are starting from a clean state
         # but so far we don't have an unregister API unfortunately, so be careful
-        _, scoring_functions_impl, _ = scoring_stack
+        _, scoring_functions_impl, _, _ = scoring_stack
         response = await scoring_functions_impl.list_scoring_functions()
         assert isinstance(response, list)
         assert len(response) > 0
 
     @pytest.mark.asyncio
     async def test_scoring_score(self, scoring_stack):
-        scoring_impl, scoring_functions_impl, datasets_impl = scoring_stack
+        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
+            scoring_stack
+        )
         await register_dataset(datasets_impl)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
+
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        scoring_functions = [
+            "meta-reference::llm_as_judge_8b_correctness",
+            "meta-reference::equality",
+        ]
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
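For context, a minimal sketch of the end-to-end flow the new test_scoring_score case exercises is included below. It is illustrative only: the wrapper function score_sample_rows is not part of this change, and it assumes the caller already has the Api.datasetio and Api.scoring implementations returned by the updated scoring_stack fixture and has registered the test dataset.

# Illustrative sketch only (not part of this diff): mirrors the flow of the new
# test_scoring_score case. `datasetio_impl` and `scoring_impl` stand in for the
# Api.datasetio and Api.scoring impls from the updated scoring_stack fixture.
async def score_sample_rows(datasetio_impl, scoring_impl) -> None:
    # Pull a small page of rows from the dataset registered by register_dataset().
    rows = await datasetio_impl.get_rows_paginated(
        dataset_id="test_dataset",
        rows_in_page=3,
    )

    scoring_functions = [
        "meta-reference::llm_as_judge_8b_correctness",
        "meta-reference::equality",
    ]

    # Score every row with both functions; the judge function reads its prompt
    # template, judge model, and score regexes from the renamed `params` field.
    response = await scoring_impl.score(
        input_rows=rows.rows,
        scoring_functions=scoring_functions,
    )

    # results is keyed by scoring function identifier, with one score row per input row.
    for fn_id in scoring_functions:
        print(fn_id, response.results[fn_id].score_rows)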