From 56239fce908ca0e580f83f5a90e4eb0e61ebb811 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 6 Nov 2024 18:07:16 -0800
Subject: [PATCH] scoring fix

---
 llama_stack/apis/scoring/scoring.py           |  6 ++-
 .../scoring_functions/scoring_functions.py    |  5 +-
 llama_stack/distribution/routers/routers.py   |  8 ++-
 .../inline/meta_reference/scoring/scoring.py  | 20 ++++++--
 .../scoring/scoring_fn/base_scoring_fn.py     |  8 ++-
 .../scoring/scoring_fn/equality_scoring_fn.py |  1 +
 .../fn_defs/llm_as_judge_8b_correctness.py    |  6 ++-
 .../scoring_fn/llm_as_judge_scoring_fn.py     | 14 ++++--
 .../scoring_fn/subset_of_scoring_fn.py        |  1 +
 .../providers/tests/scoring/test_scoring.py   | 50 +++++++++++++++++++
 10 files changed, 104 insertions(+), 15 deletions(-)

diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py
index 1fd523dcb..a518bd806 100644
--- a/llama_stack/apis/scoring/scoring.py
+++ b/llama_stack/apis/scoring/scoring.py
@@ -49,10 +49,14 @@ class Scoring(Protocol):
         self,
         dataset_id: str,
         scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...
 
     @webmethod(route="/scoring/score")
     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse: ...
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index b2f28e374..341f84c36 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -38,7 +38,10 @@ class LLMAsJudgeScoringFnParams(BaseModel):
     )
     judge_model: str
     prompt_template: Optional[str] = None
-    judge_score_regex: Optional[List[str]] = Field()
+    judge_score_regexes: Optional[List[str]] = Field(
+        description="Regexes to extract the answer from generated response",
+        default_factory=list,
+    )
 
 
 @json_schema_type
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 760dbaf2f..5b8274245 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -212,6 +212,7 @@ class ScoringRouter(Scoring):
         self,
         dataset_id: str,
         scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
@@ -221,6 +222,7 @@ class ScoringRouter(Scoring):
             ).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions=[fn_identifier],
+                scoring_params=scoring_params,
             )
             res.update(score_response.results)
 
@@ -232,7 +234,10 @@ class ScoringRouter(Scoring):
         )
 
     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
@@ -242,6 +247,7 @@ class ScoringRouter(Scoring):
             ).score(
                 input_rows=input_rows,
                 scoring_functions=[fn_identifier],
+                scoring_params=scoring_params,
             )
             res.update(score_response.results)
 
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
index 709b2f0c6..56cf13503 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
@@ -74,8 +74,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         return scoring_fn_defs_list
 
     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
-        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
-        self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
+        raise NotImplementedError("Register scoring function not implemented yet")
 
     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -98,6 +97,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self,
         dataset_id: str,
         scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -106,7 +106,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             rows_in_page=-1,
         )
         res = await self.score(
-            input_rows=all_rows.rows, scoring_functions=scoring_functions
+            input_rows=all_rows.rows,
+            scoring_functions=scoring_functions,
+            scoring_params=scoring_params,
         )
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
@@ -118,14 +120,22 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         )
 
     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions:
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            scoring_fn_params = None
+            if scoring_params is not None:
+                scoring_fn_params = scoring_params.get(scoring_fn_id, None)
+            score_results = await scoring_fn.score(
+                input_rows, scoring_fn_id, scoring_fn_params
+            )
             agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py
index cbd875be6..532686ebd 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py
@@ -36,7 +36,10 @@ class BaseScoringFn(ABC):
 
     @abstractmethod
     async def score_row(
-        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         raise NotImplementedError()
 
@@ -50,8 +53,9 @@ class BaseScoringFn(ABC):
         self,
         input_rows: List[Dict[str, Any]],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> List[ScoringResultRow]:
         return [
-            await self.score_row(input_row, scoring_fn_identifier)
+            await self.score_row(input_row, scoring_fn_identifier, scoring_params)
             for input_row in input_rows
         ]
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py
index 2a0cd0578..07405d56c 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py
@@ -35,6 +35,7 @@ class EqualityScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "equality",
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
index 09c7badd5..cfef52160 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
@@ -31,6 +31,10 @@ llm_as_judge_8b_correctness = ScoringFnDef(
     params=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
         judge_model="Llama3.1-8B-Instruct",
-        judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
+        judge_score_regexes=[
+            r"Total rating: (\d+)",
+            r"rating: (\d+)",
+            r"Rating: (\d+)",
+        ],
     ),
 )
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
index 3c594eac8..f98f7fb5e 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
@@ -36,18 +36,24 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert (
             scoring_fn_identifier is not None
         ), "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+
+        # override params if scoring_params is provided
+        if scoring_params is not None:
+            fn_def.params = scoring_params
+
         assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
         assert (
             fn_def.params.prompt_template is not None
         ), "LLM Judge prompt_template not found."
         assert (
-            fn_def.params.judge_score_regex is not None
-        ), "LLM Judge judge_score_regex not found."
+            fn_def.params.judge_score_regexes is not None
+        ), "LLM Judge judge_score_regexes not found."
 
         input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
@@ -69,10 +75,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             ],
         )
         content = judge_response.completion_message.content
-        rating_regexs = fn_def.params.judge_score_regex
+        rating_regexes = fn_def.params.judge_score_regexes
 
         judge_rating = None
-        for regex in rating_regexs:
+        for regex in rating_regexes:
             match = re.search(regex, content)
             if match:
                 judge_rating = int(match.group(1))
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
index f42964c1f..289c63dd7 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
@@ -34,6 +34,7 @@ class SubsetOfScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "subset_of",
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 2c6ff1641..7a36a96f9 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -7,6 +7,8 @@
 
 import pytest
 
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 
 # How to run this test:
@@ -64,3 +66,51 @@ class TestScoring:
         for x in scoring_functions:
             assert x in response.results
             assert len(response.results[x].score_rows) == 5
+
+    @pytest.mark.asyncio
+    async def test_scoring_score_with_params(self, scoring_stack):
+        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
+            scoring_stack
+        )
+        await register_dataset(datasets_impl)
+        response = await datasets_impl.list_datasets()
+        assert len(response) == 1
+
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        params = {
+            "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
+                judge_model="Llama3.1-405B-Instruct",
+                prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
+                judge_score_regexes=[r"Score: (\d+)"],
+            )
+        }
+
+        scoring_functions = [
+            "meta-reference::llm_as_judge_8b_correctness",
+        ]
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+            scoring_params=params,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
+
+        # score batch
+        response = await scoring_impl.score_batch(
+            dataset_id="test_dataset",
+            scoring_functions=scoring_functions,
+            scoring_params=params,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == 5
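
Usage sketch (illustrative, not part of the patch): the snippet below shows how a caller might pass the new scoring_params argument, mirroring the test added above. The names scoring_impl and input_rows stand in for a Scoring implementation and a list of dataset rows obtained as in the test; importing LLMAsJudgeScoringFnParams by name assumes llama_stack.apis.scoring_functions exports it, as the test's star import suggests.

from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams


async def score_with_judge_override(scoring_impl, input_rows):
    # Per-request override: these params replace the registered defaults for
    # meta-reference::llm_as_judge_8b_correctness on this call only.
    params = {
        "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
            judge_model="Llama3.1-405B-Instruct",
            prompt_template=(
                "Output a number response in the following format: "
                "Score: <number>, where <number> is the number between 0 and 9."
            ),
            judge_score_regexes=[r"Score: (\d+)"],
        )
    }
    # scoring_params is optional; when it is omitted (None), each scoring
    # function falls back to the params on its registered ScoringFnDef.
    return await scoring_impl.score(
        input_rows=input_rows,
        scoring_functions=["meta-reference::llm_as_judge_8b_correctness"],
        scoring_params=params,
    )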