dataset validation

This commit is contained in:
Xi Yan 2024-10-23 12:08:39 -07:00
parent aefa84e70a
commit 7c280e18fb
4 changed files with 63 additions and 69 deletions

View file

@ -62,10 +62,15 @@ class PandasDataframeDataset(BaseDataset):
if self.df is None:
self.load()
print(self.dataset_def.dataset_schema)
# get columns names
# columns = self.df[self.dataset_def.dataset_schema.keys()]
print(self.df.columns)
assert self.df is not None, "Dataset loading failed. Please check logs."
self.df = self.df[self.dataset_def.dataset_schema.keys()]
# check all columns in dataset schema are present
assert len(self.df.columns) == len(self.dataset_def.dataset_schema)
# check all types match
print(self.df.dtypes)
def load(self) -> None:
if self.df is not None:

View file

@ -0,0 +1,6 @@
input_query,generated_answer,expected_answer
What is the capital of France?,London,Paris
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
What is the largest planet in our solar system?,Jupiter,Jupiter
What is the smallest country in the world?,China,Vatican City
What is the currency of Japan?,Yen,Yen
1 input_query generated_answer expected_answer
2 What is the capital of France? London Paris
3 Who is the CEO of Meta? Mark Zuckerberg Mark Zuckerberg
4 What is the largest planet in our solar system? Jupiter Jupiter
5 What is the smallest country in the world? China Vatican City
6 What is the currency of Japan? Yen Yen

View file

@ -1,5 +0,0 @@
{"input_query": "What is the capital of France?", "generated_answer": "London", "expected_answer": "Paris"}
{"input_query": "Who is the CEO of Meta?", "generated_answer": "Mark Zuckerberg", "expected_answer": "Mark Zuckerberg"}
{"input_query": "What is the largest planet in our solar system?", "generated_answer": "Jupiter", "expected_answer": "Jupiter"}
{"input_query": "What is the smallest country in the world?", "generated_answer": "China", "expected_answer": "Vatican City"}
{"input_query": "What is the currency of Japan?", "generated_answer": "Yen", "expected_answer": "Yen"}

View file

@ -63,25 +63,13 @@ def data_url_from_file(file_path: str) -> str:
async def register_dataset(datasets_impl: Datasets):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
dataset = DatasetDefWithProvider(
identifier="test_dataset",
provider_id=os.environ["PROVIDER_ID"],
url=URL(
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
),
dataset_schema={},
)
await datasets_impl.register_dataset(dataset)
async def register_local_dataset(datasets_impl: Datasets):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.jsonl"
test_jsonl_url = data_url_from_file(str(test_file))
dataset = DatasetDefWithProvider(
identifier="test_dataset",
provider_id=os.environ["PROVIDER_ID"],
url=URL(
uri=test_jsonl_url,
uri=test_url,
),
dataset_schema={
"generated_answer": StringType(),
@ -92,60 +80,60 @@ async def register_local_dataset(datasets_impl: Datasets):
await datasets_impl.register_dataset(dataset)
# @pytest.mark.asyncio
# async def test_datasets_list(datasetio_settings):
# # NOTE: this needs you to ensure that you are starting from a clean state
# # but so far we don't have an unregister API unfortunately, so be careful
# datasets_impl = datasetio_settings["datasets_impl"]
# response = await datasets_impl.list_datasets()
# assert isinstance(response, list)
# assert len(response) == 0
@pytest.mark.asyncio
async def test_datasets_list(datasetio_settings):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
datasets_impl = datasetio_settings["datasets_impl"]
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
# @pytest.mark.asyncio
# async def test_datasets_register(datasetio_settings):
# # NOTE: this needs you to ensure that you are starting from a clean state
# # but so far we don't have an unregister API unfortunately, so be careful
# datasets_impl = datasetio_settings["datasets_impl"]
# await register_dataset(datasets_impl)
@pytest.mark.asyncio
async def test_datasets_register(datasetio_settings):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
datasets_impl = datasetio_settings["datasets_impl"]
await register_dataset(datasets_impl)
# response = await datasets_impl.list_datasets()
# assert isinstance(response, list)
# assert len(response) == 1
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 1
# # register same dataset with same id again will fail
# await register_dataset(datasets_impl)
# response = await datasets_impl.list_datasets()
# assert isinstance(response, list)
# assert len(response) == 1
# assert response[0].identifier == "test_dataset"
# register same dataset with same id again will fail
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 1
assert response[0].identifier == "test_dataset"
# @pytest.mark.asyncio
# async def test_get_rows_paginated(datasetio_settings):
# datasetio_impl = datasetio_settings["datasetio_impl"]
# datasets_impl = datasetio_settings["datasets_impl"]
# await register_dataset(datasets_impl)
@pytest.mark.asyncio
async def test_get_rows_paginated(datasetio_settings):
datasetio_impl = datasetio_settings["datasetio_impl"]
datasets_impl = datasetio_settings["datasets_impl"]
await register_dataset(datasets_impl)
# response = await datasetio_impl.get_rows_paginated(
# dataset_id="test_dataset",
# rows_in_page=3,
# )
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
# assert isinstance(response.rows, list)
# assert len(response.rows) == 3
# assert response.next_page_token == "3"
assert isinstance(response.rows, list)
assert len(response.rows) == 3
assert response.next_page_token == "3"
# # iterate over all rows
# response = await datasetio_impl.get_rows_paginated(
# dataset_id="test_dataset",
# rows_in_page=10,
# page_token=response.next_page_token,
# )
# iterate over all rows
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=2,
page_token=response.next_page_token,
)
# assert isinstance(response.rows, list)
# assert len(response.rows) == 10
# assert response.next_page_token == "13"
assert isinstance(response.rows, list)
assert len(response.rows) == 2
assert response.next_page_token == "5"
@pytest.mark.asyncio
@ -153,4 +141,4 @@ async def test_datasets_validation(datasetio_settings):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
datasets_impl = datasetio_settings["datasets_impl"]
await register_local_dataset(datasets_impl)
await register_dataset(datasets_impl)