mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 23:29:43 +00:00
dataset validation
This commit is contained in:
parent
aefa84e70a
commit
7c280e18fb
4 changed files with 63 additions and 69 deletions
|
@ -62,10 +62,15 @@ class PandasDataframeDataset(BaseDataset):
|
||||||
if self.df is None:
|
if self.df is None:
|
||||||
self.load()
|
self.load()
|
||||||
|
|
||||||
print(self.dataset_def.dataset_schema)
|
assert self.df is not None, "Dataset loading failed. Please check logs."
|
||||||
# get columns names
|
|
||||||
# columns = self.df[self.dataset_def.dataset_schema.keys()]
|
self.df = self.df[self.dataset_def.dataset_schema.keys()]
|
||||||
print(self.df.columns)
|
|
||||||
|
# check all columns in dataset schema are present
|
||||||
|
assert len(self.df.columns) == len(self.dataset_def.dataset_schema)
|
||||||
|
|
||||||
|
# check all types match
|
||||||
|
print(self.df.dtypes)
|
||||||
|
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
if self.df is not None:
|
if self.df is not None:
|
||||||
|
|
6
llama_stack/providers/tests/datasetio/test_dataset.csv
Normal file
6
llama_stack/providers/tests/datasetio/test_dataset.csv
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
input_query,generated_answer,expected_answer
|
||||||
|
What is the capital of France?,London,Paris
|
||||||
|
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
|
||||||
|
What is the largest planet in our solar system?,Jupiter,Jupiter
|
||||||
|
What is the smallest country in the world?,China,Vatican City
|
||||||
|
What is the currency of Japan?,Yen,Yen
|
|
|
@ -1,5 +0,0 @@
|
||||||
{"input_query": "What is the capital of France?", "generated_answer": "London", "expected_answer": "Paris"}
|
|
||||||
{"input_query": "Who is the CEO of Meta?", "generated_answer": "Mark Zuckerberg", "expected_answer": "Mark Zuckerberg"}
|
|
||||||
{"input_query": "What is the largest planet in our solar system?", "generated_answer": "Jupiter", "expected_answer": "Jupiter"}
|
|
||||||
{"input_query": "What is the smallest country in the world?", "generated_answer": "China", "expected_answer": "Vatican City"}
|
|
||||||
{"input_query": "What is the currency of Japan?", "generated_answer": "Yen", "expected_answer": "Yen"}
|
|
|
@ -63,25 +63,13 @@ def data_url_from_file(file_path: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
async def register_dataset(datasets_impl: Datasets):
|
async def register_dataset(datasets_impl: Datasets):
|
||||||
|
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
|
||||||
|
test_url = data_url_from_file(str(test_file))
|
||||||
dataset = DatasetDefWithProvider(
|
dataset = DatasetDefWithProvider(
|
||||||
identifier="test_dataset",
|
identifier="test_dataset",
|
||||||
provider_id=os.environ["PROVIDER_ID"],
|
provider_id=os.environ["PROVIDER_ID"],
|
||||||
url=URL(
|
url=URL(
|
||||||
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
|
uri=test_url,
|
||||||
),
|
|
||||||
dataset_schema={},
|
|
||||||
)
|
|
||||||
await datasets_impl.register_dataset(dataset)
|
|
||||||
|
|
||||||
|
|
||||||
async def register_local_dataset(datasets_impl: Datasets):
|
|
||||||
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.jsonl"
|
|
||||||
test_jsonl_url = data_url_from_file(str(test_file))
|
|
||||||
dataset = DatasetDefWithProvider(
|
|
||||||
identifier="test_dataset",
|
|
||||||
provider_id=os.environ["PROVIDER_ID"],
|
|
||||||
url=URL(
|
|
||||||
uri=test_jsonl_url,
|
|
||||||
),
|
),
|
||||||
dataset_schema={
|
dataset_schema={
|
||||||
"generated_answer": StringType(),
|
"generated_answer": StringType(),
|
||||||
|
@ -92,60 +80,60 @@ async def register_local_dataset(datasets_impl: Datasets):
|
||||||
await datasets_impl.register_dataset(dataset)
|
await datasets_impl.register_dataset(dataset)
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
# async def test_datasets_list(datasetio_settings):
|
async def test_datasets_list(datasetio_settings):
|
||||||
# # NOTE: this needs you to ensure that you are starting from a clean state
|
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
# # but so far we don't have an unregister API unfortunately, so be careful
|
# but so far we don't have an unregister API unfortunately, so be careful
|
||||||
# datasets_impl = datasetio_settings["datasets_impl"]
|
datasets_impl = datasetio_settings["datasets_impl"]
|
||||||
# response = await datasets_impl.list_datasets()
|
response = await datasets_impl.list_datasets()
|
||||||
# assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
# assert len(response) == 0
|
assert len(response) == 0
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
# async def test_datasets_register(datasetio_settings):
|
async def test_datasets_register(datasetio_settings):
|
||||||
# # NOTE: this needs you to ensure that you are starting from a clean state
|
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
# # but so far we don't have an unregister API unfortunately, so be careful
|
# but so far we don't have an unregister API unfortunately, so be careful
|
||||||
# datasets_impl = datasetio_settings["datasets_impl"]
|
datasets_impl = datasetio_settings["datasets_impl"]
|
||||||
# await register_dataset(datasets_impl)
|
await register_dataset(datasets_impl)
|
||||||
|
|
||||||
# response = await datasets_impl.list_datasets()
|
response = await datasets_impl.list_datasets()
|
||||||
# assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
# assert len(response) == 1
|
assert len(response) == 1
|
||||||
|
|
||||||
# # register same dataset with same id again will fail
|
# register same dataset with same id again will fail
|
||||||
# await register_dataset(datasets_impl)
|
await register_dataset(datasets_impl)
|
||||||
# response = await datasets_impl.list_datasets()
|
response = await datasets_impl.list_datasets()
|
||||||
# assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
# assert len(response) == 1
|
assert len(response) == 1
|
||||||
# assert response[0].identifier == "test_dataset"
|
assert response[0].identifier == "test_dataset"
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
# async def test_get_rows_paginated(datasetio_settings):
|
async def test_get_rows_paginated(datasetio_settings):
|
||||||
# datasetio_impl = datasetio_settings["datasetio_impl"]
|
datasetio_impl = datasetio_settings["datasetio_impl"]
|
||||||
# datasets_impl = datasetio_settings["datasets_impl"]
|
datasets_impl = datasetio_settings["datasets_impl"]
|
||||||
# await register_dataset(datasets_impl)
|
await register_dataset(datasets_impl)
|
||||||
|
|
||||||
# response = await datasetio_impl.get_rows_paginated(
|
response = await datasetio_impl.get_rows_paginated(
|
||||||
# dataset_id="test_dataset",
|
dataset_id="test_dataset",
|
||||||
# rows_in_page=3,
|
rows_in_page=3,
|
||||||
# )
|
)
|
||||||
|
|
||||||
# assert isinstance(response.rows, list)
|
assert isinstance(response.rows, list)
|
||||||
# assert len(response.rows) == 3
|
assert len(response.rows) == 3
|
||||||
# assert response.next_page_token == "3"
|
assert response.next_page_token == "3"
|
||||||
|
|
||||||
# # iterate over all rows
|
# iterate over all rows
|
||||||
# response = await datasetio_impl.get_rows_paginated(
|
response = await datasetio_impl.get_rows_paginated(
|
||||||
# dataset_id="test_dataset",
|
dataset_id="test_dataset",
|
||||||
# rows_in_page=10,
|
rows_in_page=2,
|
||||||
# page_token=response.next_page_token,
|
page_token=response.next_page_token,
|
||||||
# )
|
)
|
||||||
|
|
||||||
# assert isinstance(response.rows, list)
|
assert isinstance(response.rows, list)
|
||||||
# assert len(response.rows) == 10
|
assert len(response.rows) == 2
|
||||||
# assert response.next_page_token == "13"
|
assert response.next_page_token == "5"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -153,4 +141,4 @@ async def test_datasets_validation(datasetio_settings):
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
# but so far we don't have an unregister API unfortunately, so be careful
|
||||||
datasets_impl = datasetio_settings["datasets_impl"]
|
datasets_impl = datasetio_settings["datasets_impl"]
|
||||||
await register_local_dataset(datasets_impl)
|
await register_dataset(datasets_impl)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue