From afa0c2b14676b17d3cfea606be8c456bb3d173c7 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 23 Oct 2024 22:17:38 -0700 Subject: [PATCH] address comments --- .../impls/meta_reference/datasetio/datasetio.py | 16 +++++++++------- .../impls/meta_reference/scoring/scoring.py | 10 ++++++---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py index 77ae35538..57ce8e10f 100644 --- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py @@ -61,13 +61,13 @@ class PandasDataframeDataset(BaseDataset): else: return self.df.iloc[idx].to_dict() - def _validate_dataset_schema(self) -> None: - assert self.df is not None, "Dataset not loaded. Please call .load() first" + def _validate_dataset_schema(self, df) -> pandas.DataFrame: # note that we will drop any columns in dataset that are not in the schema - self.df = self.df[self.dataset_def.dataset_schema.keys()] + df = df[self.dataset_def.dataset_schema.keys()] # check all columns in dataset schema are present - assert len(self.df.columns) == len(self.dataset_def.dataset_schema) + assert len(df.columns) == len(self.dataset_def.dataset_schema) # TODO: type checking against column types in dataset schema + return df def load(self) -> None: if self.df is not None: @@ -99,8 +99,7 @@ class PandasDataframeDataset(BaseDataset): else: raise ValueError(f"Unsupported file type: {self.dataset_def.url}") - self.df = df - self._validate_dataset_schema() + self.df = self._validate_dataset_schema(df) class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): @@ -136,7 +135,10 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): dataset_info = self.dataset_infos.get(dataset_id) dataset_info.dataset_impl.load() - if page_token is None or not page_token.isnumeric(): + if page_token and not page_token.isnumeric(): + raise ValueError("Invalid page_token") + + if page_token is None: next_page_token = 0 else: next_page_token = int(page_token) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index 662d6e0b7..73f9fcc5a 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -59,11 +59,11 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): for required_column in ["generated_answer", "expected_answer", "input_query"]: if required_column not in dataset_def.dataset_schema: raise ValueError( - f"Dataset {dataset_id} does not have a '{required_column}' column. Please make sure '{required_column}' column is in the dataset." + f"Dataset {dataset_id} does not have a '{required_column}' column." ) if dataset_def.dataset_schema[required_column].type != "string": raise ValueError( - f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'. Please make sure '{required_column}' column is of type 'string'." + f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." ) async def score_batch( @@ -73,12 +73,12 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): save_results_dataset: bool = False, ) -> ScoreBatchResponse: await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) - rows_paginated = await self.datasetio_api.get_rows_paginated( + all_rows = await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, rows_in_page=-1, ) res = await self.score( - input_rows=rows_paginated.rows, scoring_functions=scoring_functions + input_rows=all_rows.rows, scoring_functions=scoring_functions ) if save_results_dataset: # TODO: persist and register dataset on to server for reading @@ -94,6 +94,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): ) -> ScoreResponse: res = {} for scoring_fn_id in scoring_functions: + if scoring_fn_id not in SCORER_REGISTRY: + raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") scorer = SCORER_REGISTRY[scoring_fn_id]() score_results = scorer.score(input_rows) agg_results = scorer.aggregate(score_results)