From 59c93548bc01f2a0d5694ac6fd294eea7180d24b Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 23 Oct 2024 17:43:41 -0700
Subject: [PATCH] validate scorer input

---
Review notes (this area is ignored when the patch is applied):
- Hunks were edited during review, so the post-image blob ids on the
  `index` lines below no longer match the new file contents.
- Row validation now raises ValueError instead of using `assert`, which is
  stripped under `python -O`.

 .../scoring/scorer/equality_scorer.py       |  7 +++++++
 .../impls/meta_reference/scoring/scoring.py | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py
index 9fd3716b7..e79a96068 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py
@@ -25,6 +25,13 @@ class EqualityScorer(BaseScorer):
     )
 
     def score_row(self, input_row: Dict[str, Any]) -> ScoringResult:
+        # Explicit raises instead of `assert`: asserts are stripped under
+        # `python -O`, which would silently disable this validation.
+        if "expected_answer" not in input_row:
+            raise ValueError("Expected answer not found in input row.")
+        if "generated_answer" not in input_row:
+            raise ValueError("Generated answer not found in input row.")
+
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
         score = 1.0 if expected_answer == generated_answer else 0.0
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
index 8b07bcbd8..662d6e0b7 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
@@ -49,12 +49,46 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             "Dynamically registering scoring functions is not supported"
         )
 
+    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
+        """Check that a dataset is usable as scoring input.
+
+        The dataset must define a schema containing string-typed
+        'generated_answer', 'expected_answer' and 'input_query' columns.
+
+        Raises:
+            ValueError: if the dataset is not found, has no schema, or a
+                required column is missing or not of type 'string'.
+        """
+        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
+        # NOTE(review): assumes get_dataset returns None for an unknown id;
+        # confirm against the Datasets API.
+        if dataset_def is None:
+            raise ValueError(f"Dataset {dataset_id} not found.")
+
+        dataset_schema = dataset_def.dataset_schema
+        # An empty schema is falsy, so no separate length check is needed.
+        if not dataset_schema:
+            raise ValueError(
+                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
+            )
+
+        for required_column in ("generated_answer", "expected_answer", "input_query"):
+            if required_column not in dataset_schema:
+                raise ValueError(
+                    f"Dataset {dataset_id} does not have a '{required_column}' column. Please make sure '{required_column}' column is in the dataset."
+                )
+            if dataset_schema[required_column].type != "string":
+                raise ValueError(
+                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'. Please make sure '{required_column}' column is of type 'string'."
+                )
+
     async def score_batch(
         self,
         dataset_id: str,
         scoring_functions: List[str],
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
+        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
         rows_paginated = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=-1,