diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py index 3786b5857..7297f5d88 100644 --- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py @@ -62,10 +62,15 @@ class PandasDataframeDataset(BaseDataset): if self.df is None: self.load() - print(self.dataset_def.dataset_schema) - # get columns names - # columns = self.df[self.dataset_def.dataset_schema.keys()] - print(self.df.columns) + assert self.df is not None, "Dataset loading failed. Please check logs." + + self.df = self.df[self.dataset_def.dataset_schema.keys()] + + # check all columns in dataset schema are present + assert len(self.df.columns) == len(self.dataset_def.dataset_schema) + + # check all types match + print(self.df.dtypes) def load(self) -> None: if self.df is not None: diff --git a/llama_stack/providers/tests/datasetio/test_dataset.csv b/llama_stack/providers/tests/datasetio/test_dataset.csv new file mode 100644 index 000000000..a1a250753 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/test_dataset.csv @@ -0,0 +1,6 @@ +input_query,generated_answer,expected_answer +What is the capital of France?,London,Paris +Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg +What is the largest planet in our solar system?,Jupiter,Jupiter +What is the smallest country in the world?,China,Vatican City +What is the currency of Japan?,Yen,Yen diff --git a/llama_stack/providers/tests/datasetio/test_dataset.jsonl b/llama_stack/providers/tests/datasetio/test_dataset.jsonl deleted file mode 100644 index dfdfff458..000000000 --- a/llama_stack/providers/tests/datasetio/test_dataset.jsonl +++ /dev/null @@ -1,5 +0,0 @@ -{"input_query": "What is the capital of France?", "generated_answer": "London", "expected_answer": "Paris"} -{"input_query": "Who is the CEO of Meta?", "generated_answer": "Mark Zuckerberg", "expected_answer": "Mark Zuckerberg"} -{"input_query": "What is the largest planet in our solar system?", "generated_answer": "Jupiter", "expected_answer": "Jupiter"} -{"input_query": "What is the smallest country in the world?", "generated_answer": "China", "expected_answer": "Vatican City"} -{"input_query": "What is the currency of Japan?", "generated_answer": "Yen", "expected_answer": "Yen"} diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 4114207aa..9adef4a58 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -63,25 +63,13 @@ def data_url_from_file(file_path: str) -> str: async def register_dataset(datasets_impl: Datasets): + test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" + test_url = data_url_from_file(str(test_file)) dataset = DatasetDefWithProvider( identifier="test_dataset", provider_id=os.environ["PROVIDER_ID"], url=URL( - uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", - ), - dataset_schema={}, - ) - await datasets_impl.register_dataset(dataset) - - -async def register_local_dataset(datasets_impl: Datasets): - test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.jsonl" - test_jsonl_url = data_url_from_file(str(test_file)) - dataset = DatasetDefWithProvider( - identifier="test_dataset", - provider_id=os.environ["PROVIDER_ID"], - url=URL( - uri=test_jsonl_url, + uri=test_url, ), dataset_schema={ "generated_answer": StringType(), @@ -92,60 +80,60 @@ async def register_local_dataset(datasets_impl: Datasets): await datasets_impl.register_dataset(dataset) -# @pytest.mark.asyncio -# async def test_datasets_list(datasetio_settings): -# # NOTE: this needs you to ensure that you are starting from a clean state -# # but so far we don't have an unregister API unfortunately, so be careful -# datasets_impl = datasetio_settings["datasets_impl"] -# response = await datasets_impl.list_datasets() -# assert isinstance(response, list) -# assert len(response) == 0 +@pytest.mark.asyncio +async def test_datasets_list(datasetio_settings): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + datasets_impl = datasetio_settings["datasets_impl"] + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 0 -# @pytest.mark.asyncio -# async def test_datasets_register(datasetio_settings): -# # NOTE: this needs you to ensure that you are starting from a clean state -# # but so far we don't have an unregister API unfortunately, so be careful -# datasets_impl = datasetio_settings["datasets_impl"] -# await register_dataset(datasets_impl) +@pytest.mark.asyncio +async def test_datasets_register(datasetio_settings): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + datasets_impl = datasetio_settings["datasets_impl"] + await register_dataset(datasets_impl) -# response = await datasets_impl.list_datasets() -# assert isinstance(response, list) -# assert len(response) == 1 + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 1 -# # register same dataset with same id again will fail -# await register_dataset(datasets_impl) -# response = await datasets_impl.list_datasets() -# assert isinstance(response, list) -# assert len(response) == 1 -# assert response[0].identifier == "test_dataset" + # register same dataset with same id again will fail + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 1 + assert response[0].identifier == "test_dataset" -# @pytest.mark.asyncio -# async def test_get_rows_paginated(datasetio_settings): -# datasetio_impl = datasetio_settings["datasetio_impl"] -# datasets_impl = datasetio_settings["datasets_impl"] -# await register_dataset(datasets_impl) +@pytest.mark.asyncio +async def test_get_rows_paginated(datasetio_settings): + datasetio_impl = datasetio_settings["datasetio_impl"] + datasets_impl = datasetio_settings["datasets_impl"] + await register_dataset(datasets_impl) -# response = await datasetio_impl.get_rows_paginated( -# dataset_id="test_dataset", -# rows_in_page=3, -# ) + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=3, + ) -# assert isinstance(response.rows, list) -# assert len(response.rows) == 3 -# assert response.next_page_token == "3" + assert isinstance(response.rows, list) + assert len(response.rows) == 3 + assert response.next_page_token == "3" -# # iterate over all rows -# response = await datasetio_impl.get_rows_paginated( -# dataset_id="test_dataset", -# rows_in_page=10, -# page_token=response.next_page_token, -# ) + # iterate over all rows + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=2, + page_token=response.next_page_token, + ) -# assert isinstance(response.rows, list) -# assert len(response.rows) == 10 -# assert response.next_page_token == "13" + assert isinstance(response.rows, list) + assert len(response.rows) == 2 + assert response.next_page_token == "5" @pytest.mark.asyncio @@ -153,4 +141,4 @@ async def test_datasets_validation(datasetio_settings): # NOTE: this needs you to ensure that you are starting from a clean state # but so far we don't have an unregister API unfortunately, so be careful datasets_impl = datasetio_settings["datasets_impl"] - await register_local_dataset(datasets_impl) + await register_dataset(datasets_impl)