dataset validation

This commit is contained in:
Xi Yan 2024-10-23 12:08:39 -07:00
parent aefa84e70a
commit 7c280e18fb
4 changed files with 63 additions and 69 deletions

View file

@ -63,25 +63,13 @@ def data_url_from_file(file_path: str) -> str:
async def register_dataset(datasets_impl: Datasets):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
dataset = DatasetDefWithProvider(
identifier="test_dataset",
provider_id=os.environ["PROVIDER_ID"],
url=URL(
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
),
dataset_schema={},
)
await datasets_impl.register_dataset(dataset)
async def register_local_dataset(datasets_impl: Datasets):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.jsonl"
test_jsonl_url = data_url_from_file(str(test_file))
dataset = DatasetDefWithProvider(
identifier="test_dataset",
provider_id=os.environ["PROVIDER_ID"],
url=URL(
uri=test_jsonl_url,
uri=test_url,
),
dataset_schema={
"generated_answer": StringType(),
@ -92,60 +80,60 @@ async def register_local_dataset(datasets_impl: Datasets):
await datasets_impl.register_dataset(dataset)
# @pytest.mark.asyncio
# async def test_datasets_list(datasetio_settings):
# # NOTE: this needs you to ensure that you are starting from a clean state
# # but so far we don't have an unregister API unfortunately, so be careful
# datasets_impl = datasetio_settings["datasets_impl"]
# response = await datasets_impl.list_datasets()
# assert isinstance(response, list)
# assert len(response) == 0
@pytest.mark.asyncio
async def test_datasets_list(datasetio_settings):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
datasets_impl = datasetio_settings["datasets_impl"]
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
# @pytest.mark.asyncio
# async def test_datasets_register(datasetio_settings):
# # NOTE: this needs you to ensure that you are starting from a clean state
# # but so far we don't have an unregister API unfortunately, so be careful
# datasets_impl = datasetio_settings["datasets_impl"]
# await register_dataset(datasets_impl)
@pytest.mark.asyncio
async def test_datasets_register(datasetio_settings):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
datasets_impl = datasetio_settings["datasets_impl"]
await register_dataset(datasets_impl)
# response = await datasets_impl.list_datasets()
# assert isinstance(response, list)
# assert len(response) == 1
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 1
# # register same dataset with same id again will fail
# await register_dataset(datasets_impl)
# response = await datasets_impl.list_datasets()
# assert isinstance(response, list)
# assert len(response) == 1
# assert response[0].identifier == "test_dataset"
# register same dataset with same id again will fail
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 1
assert response[0].identifier == "test_dataset"
# @pytest.mark.asyncio
# async def test_get_rows_paginated(datasetio_settings):
# datasetio_impl = datasetio_settings["datasetio_impl"]
# datasets_impl = datasetio_settings["datasets_impl"]
# await register_dataset(datasets_impl)
@pytest.mark.asyncio
async def test_get_rows_paginated(datasetio_settings):
datasetio_impl = datasetio_settings["datasetio_impl"]
datasets_impl = datasetio_settings["datasets_impl"]
await register_dataset(datasets_impl)
# response = await datasetio_impl.get_rows_paginated(
# dataset_id="test_dataset",
# rows_in_page=3,
# )
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
# assert isinstance(response.rows, list)
# assert len(response.rows) == 3
# assert response.next_page_token == "3"
assert isinstance(response.rows, list)
assert len(response.rows) == 3
assert response.next_page_token == "3"
# # iterate over all rows
# response = await datasetio_impl.get_rows_paginated(
# dataset_id="test_dataset",
# rows_in_page=10,
# page_token=response.next_page_token,
# )
# iterate over all rows
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=2,
page_token=response.next_page_token,
)
# assert isinstance(response.rows, list)
# assert len(response.rows) == 10
# assert response.next_page_token == "13"
assert isinstance(response.rows, list)
assert len(response.rows) == 2
assert response.next_page_token == "5"
@pytest.mark.asyncio
@ -153,4 +141,4 @@ async def test_datasets_validation(datasetio_settings):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
datasets_impl = datasetio_settings["datasets_impl"]
await register_local_dataset(datasets_impl)
await register_dataset(datasets_impl)