dataset validation

This commit is contained in:
Xi Yan 2024-10-23 12:08:39 -07:00
parent aefa84e70a
commit 7c280e18fb
4 changed files with 63 additions and 69 deletions

View file

@ -62,10 +62,15 @@ class PandasDataframeDataset(BaseDataset):
if self.df is None: if self.df is None:
self.load() self.load()
print(self.dataset_def.dataset_schema) assert self.df is not None, "Dataset loading failed. Please check logs."
# get columns names
# columns = self.df[self.dataset_def.dataset_schema.keys()] self.df = self.df[self.dataset_def.dataset_schema.keys()]
print(self.df.columns)
# check all columns in dataset schema are present
assert len(self.df.columns) == len(self.dataset_def.dataset_schema)
# check all types match
print(self.df.dtypes)
def load(self) -> None: def load(self) -> None:
if self.df is not None: if self.df is not None:

View file

@ -0,0 +1,6 @@
input_query,generated_answer,expected_answer
What is the capital of France?,London,Paris
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
What is the largest planet in our solar system?,Jupiter,Jupiter
What is the smallest country in the world?,China,Vatican City
What is the currency of Japan?,Yen,Yen
1 input_query generated_answer expected_answer
2 What is the capital of France? London Paris
3 Who is the CEO of Meta? Mark Zuckerberg Mark Zuckerberg
4 What is the largest planet in our solar system? Jupiter Jupiter
5 What is the smallest country in the world? China Vatican City
6 What is the currency of Japan? Yen Yen

View file

@ -1,5 +0,0 @@
{"input_query": "What is the capital of France?", "generated_answer": "London", "expected_answer": "Paris"}
{"input_query": "Who is the CEO of Meta?", "generated_answer": "Mark Zuckerberg", "expected_answer": "Mark Zuckerberg"}
{"input_query": "What is the largest planet in our solar system?", "generated_answer": "Jupiter", "expected_answer": "Jupiter"}
{"input_query": "What is the smallest country in the world?", "generated_answer": "China", "expected_answer": "Vatican City"}
{"input_query": "What is the currency of Japan?", "generated_answer": "Yen", "expected_answer": "Yen"}

View file

@ -63,25 +63,13 @@ def data_url_from_file(file_path: str) -> str:
async def register_dataset(datasets_impl: Datasets): async def register_dataset(datasets_impl: Datasets):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
dataset = DatasetDefWithProvider( dataset = DatasetDefWithProvider(
identifier="test_dataset", identifier="test_dataset",
provider_id=os.environ["PROVIDER_ID"], provider_id=os.environ["PROVIDER_ID"],
url=URL( url=URL(
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", uri=test_url,
),
dataset_schema={},
)
await datasets_impl.register_dataset(dataset)
async def register_local_dataset(datasets_impl: Datasets):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.jsonl"
test_jsonl_url = data_url_from_file(str(test_file))
dataset = DatasetDefWithProvider(
identifier="test_dataset",
provider_id=os.environ["PROVIDER_ID"],
url=URL(
uri=test_jsonl_url,
), ),
dataset_schema={ dataset_schema={
"generated_answer": StringType(), "generated_answer": StringType(),
@ -92,60 +80,60 @@ async def register_local_dataset(datasets_impl: Datasets):
await datasets_impl.register_dataset(dataset) await datasets_impl.register_dataset(dataset)
# @pytest.mark.asyncio @pytest.mark.asyncio
# async def test_datasets_list(datasetio_settings): async def test_datasets_list(datasetio_settings):
# # NOTE: this needs you to ensure that you are starting from a clean state # NOTE: this needs you to ensure that you are starting from a clean state
# # but so far we don't have an unregister API unfortunately, so be careful # but so far we don't have an unregister API unfortunately, so be careful
# datasets_impl = datasetio_settings["datasets_impl"] datasets_impl = datasetio_settings["datasets_impl"]
# response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()
# assert isinstance(response, list) assert isinstance(response, list)
# assert len(response) == 0 assert len(response) == 0
# @pytest.mark.asyncio @pytest.mark.asyncio
# async def test_datasets_register(datasetio_settings): async def test_datasets_register(datasetio_settings):
# # NOTE: this needs you to ensure that you are starting from a clean state # NOTE: this needs you to ensure that you are starting from a clean state
# # but so far we don't have an unregister API unfortunately, so be careful # but so far we don't have an unregister API unfortunately, so be careful
# datasets_impl = datasetio_settings["datasets_impl"] datasets_impl = datasetio_settings["datasets_impl"]
# await register_dataset(datasets_impl) await register_dataset(datasets_impl)
# response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()
# assert isinstance(response, list) assert isinstance(response, list)
# assert len(response) == 1 assert len(response) == 1
# # register same dataset with same id again will fail # register same dataset with same id again will fail
# await register_dataset(datasets_impl) await register_dataset(datasets_impl)
# response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()
# assert isinstance(response, list) assert isinstance(response, list)
# assert len(response) == 1 assert len(response) == 1
# assert response[0].identifier == "test_dataset" assert response[0].identifier == "test_dataset"
# @pytest.mark.asyncio @pytest.mark.asyncio
# async def test_get_rows_paginated(datasetio_settings): async def test_get_rows_paginated(datasetio_settings):
# datasetio_impl = datasetio_settings["datasetio_impl"] datasetio_impl = datasetio_settings["datasetio_impl"]
# datasets_impl = datasetio_settings["datasets_impl"] datasets_impl = datasetio_settings["datasets_impl"]
# await register_dataset(datasets_impl) await register_dataset(datasets_impl)
# response = await datasetio_impl.get_rows_paginated( response = await datasetio_impl.get_rows_paginated(
# dataset_id="test_dataset", dataset_id="test_dataset",
# rows_in_page=3, rows_in_page=3,
# ) )
# assert isinstance(response.rows, list) assert isinstance(response.rows, list)
# assert len(response.rows) == 3 assert len(response.rows) == 3
# assert response.next_page_token == "3" assert response.next_page_token == "3"
# # iterate over all rows # iterate over all rows
# response = await datasetio_impl.get_rows_paginated( response = await datasetio_impl.get_rows_paginated(
# dataset_id="test_dataset", dataset_id="test_dataset",
# rows_in_page=10, rows_in_page=2,
# page_token=response.next_page_token, page_token=response.next_page_token,
# ) )
# assert isinstance(response.rows, list) assert isinstance(response.rows, list)
# assert len(response.rows) == 10 assert len(response.rows) == 2
# assert response.next_page_token == "13" assert response.next_page_token == "5"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -153,4 +141,4 @@ async def test_datasets_validation(datasetio_settings):
# NOTE: this needs you to ensure that you are starting from a clean state # NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful # but so far we don't have an unregister API unfortunately, so be careful
datasets_impl = datasetio_settings["datasets_impl"] datasets_impl = datasetio_settings["datasets_impl"]
await register_local_dataset(datasets_impl) await register_dataset(datasets_impl)