more tests

Xi Yan 2025-03-15 16:42:47 -07:00
parent 4fee3af91f
commit 084bacc029
7 changed files with 51 additions and 126 deletions

View file

@@ -10,7 +10,7 @@ import pandas
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .config import LocalFSDatasetIOConfig
@@ -40,11 +40,13 @@ class PandasDataframeDataset:
             return
 
         if self.dataset_def.source.type == "uri":
-            self.df = get_dataframe_from_url(self.dataset_def.uri)
+            self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
         elif self.dataset_def.source.type == "rows":
             self.df = pandas.DataFrame(self.dataset_def.source.rows)
         else:
-            raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
+            raise ValueError(
+                f"Unsupported dataset source type: {self.dataset_def.source.type}"
+            )
 
         if self.df is None:
             raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
@@ -117,4 +119,6 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         dataset_impl.load()
 
         new_rows_df = pandas.DataFrame(rows)
-        dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
+        dataset_impl.df = pandas.concat(
+            [dataset_impl.df, new_rows_df], ignore_index=True
+        )
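The localfs changes above are a rename plus line wrapping, but the surrounding logic is easy to misread in hunk form. Here is a standalone sketch of the dispatch-and-append behavior with the provider machinery stripped out; SourceDef, load_dataframe, and append_rows are illustrative names, not llama_stack types:

from dataclasses import dataclass

import pandas


@dataclass
class SourceDef:
    # Mirrors the two source types the provider dispatches on.
    type: str                      # "uri" or "rows"
    uri: str | None = None
    rows: list[dict] | None = None


def load_dataframe(source: SourceDef) -> pandas.DataFrame:
    if source.type == "uri":
        # The real provider calls get_dataframe_from_uri(source.uri),
        # which also handles .xlsx files and data: URIs.
        return pandas.read_csv(source.uri)
    elif source.type == "rows":
        return pandas.DataFrame(source.rows)
    raise ValueError(f"Unsupported dataset source type: {source.type}")


def append_rows(df: pandas.DataFrame, rows: list[dict]) -> pandas.DataFrame:
    # pandas has no in-place append; concat with ignore_index=True rebuilds
    # a contiguous RangeIndex, matching the LocalFSDatasetIOImpl code above.
    return pandas.concat([df, pandas.DataFrame(rows)], ignore_index=True)


df = load_dataframe(SourceDef(type="rows", rows=[{"input_query": "q1"}]))
df = append_rows(df, [{"input_query": "q2"}])
assert list(df["input_query"]) == ["q1", "q2"]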

View file

@@ -14,14 +14,14 @@ from llama_stack.apis.common.content_types import URL
 from llama_stack.providers.utils.memory.vector_store import parse_data_url
 
 
-def get_dataframe_from_url(url: URL):
+def get_dataframe_from_uri(uri: str):
     df = None
-    if url.uri.endswith(".csv"):
-        df = pandas.read_csv(url.uri)
-    elif url.uri.endswith(".xlsx"):
-        df = pandas.read_excel(url.uri)
-    elif url.uri.startswith("data:"):
-        parts = parse_data_url(url.uri)
+    if uri.endswith(".csv"):
+        df = pandas.read_csv(uri)
+    elif uri.endswith(".xlsx"):
+        df = pandas.read_excel(uri)
+    elif uri.startswith("data:"):
+        parts = parse_data_url(uri)
         data = parts["data"]
         if parts["is_base64"]:
             data = base64.b64decode(data)
@@ -39,6 +39,6 @@ def get_dataframe_from_url(url: URL):
         else:
             df = pandas.read_excel(data_bytes)
     else:
-        raise ValueError(f"Unsupported file type: {url}")
+        raise ValueError(f"Unsupported file type: {uri}")
 
     return df
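A quick usage sketch for the renamed helper. The signature and the .csv/.xlsx/data: branches are taken from the diff; the middle of the function (MIME-type handling for decoded data URLs) is elided by the hunk, so the round-trip below assumes it dispatches on text/csv the way the base64 test later in this commit relies on. The CSV content is made up for illustration:

import base64

from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri

# Build a base64-encoded CSV data URI.
csv_bytes = b"input_query,expected_answer\nWhat is 2 + 2?,4\n"
data_uri = "data:text/csv;base64," + base64.b64encode(csv_bytes).decode("utf-8")

df = get_dataframe_from_uri(data_uri)  # plain ".csv" / ".xlsx" URIs also work
assert df.shape == (1, 2)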

View file

@@ -1,114 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import base64
-import mimetypes
-import os
-from pathlib import Path
-
-import pytest
-
-# How to run this test:
-#
-# LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasetio
-
-
-@pytest.fixture
-def dataset_for_test(llama_stack_client):
-    dataset_id = "test_dataset"
-    register_dataset(llama_stack_client, dataset_id=dataset_id)
-    yield
-    # Teardown - this always runs, even if the test fails
-    try:
-        llama_stack_client.datasets.unregister(dataset_id)
-    except Exception as e:
-        print(f"Warning: Failed to unregister test_dataset: {e}")
-
-
-def data_url_from_file(file_path: str) -> str:
-    if not os.path.exists(file_path):
-        raise FileNotFoundError(f"File not found: {file_path}")
-
-    with open(file_path, "rb") as file:
-        file_content = file.read()
-
-    base64_content = base64.b64encode(file_content).decode("utf-8")
-    mime_type, _ = mimetypes.guess_type(file_path)
-
-    data_url = f"data:{mime_type};base64,{base64_content}"
-
-    return data_url
-
-
-def register_dataset(llama_stack_client, for_generation=False, for_rag=False, dataset_id="test_dataset"):
-    if for_rag:
-        test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
-    else:
-        test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
-    test_url = data_url_from_file(str(test_file))
-
-    if for_generation:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "chat_completion_input": {"type": "chat_completion_input"},
-        }
-    elif for_rag:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "generated_answer": {"type": "string"},
-            "context": {"type": "string"},
-        }
-    else:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "generated_answer": {"type": "string"},
-        }
-
-    dataset_providers = [x for x in llama_stack_client.providers.list() if x.api == "datasetio"]
-    dataset_provider_id = dataset_providers[0].provider_id
-
-    llama_stack_client.datasets.register(
-        dataset_id=dataset_id,
-        dataset_schema=dataset_schema,
-        url=dict(uri=test_url),
-        provider_id=dataset_provider_id,
-    )
-
-
-def test_register_unregister_dataset(llama_stack_client):
-    register_dataset(llama_stack_client)
-    response = llama_stack_client.datasets.list()
-    assert isinstance(response, list)
-    assert len(response) == 1
-    assert response[0].identifier == "test_dataset"
-
-    llama_stack_client.datasets.unregister("test_dataset")
-    response = llama_stack_client.datasets.list()
-    assert isinstance(response, list)
-    assert len(response) == 0
-
-
-def test_get_rows_paginated(llama_stack_client, dataset_for_test):
-    response = llama_stack_client.datasetio.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=3,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 3
-    assert response.next_page_token == "3"
-
-    # iterate over all rows
-    response = llama_stack_client.datasetio.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=2,
-        page_token=response.next_page_token,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 2
-    assert response.next_page_token == "5"
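The file deleted above exercised the pre-refactor surface: registration with an explicit dataset_schema plus a url dict, and reads via token-based get_rows_paginated. The tests added below switch to the purpose/source registration and datasets.iterrows. Side by side, with argument values taken from the two test files in this commit (client, test_url, and dataset_provider_id stand in for the surrounding fixtures, so this is a comparison sketch rather than runnable code):

# Before (deleted above): explicit schema, url dict, token pagination.
client.datasets.register(
    dataset_id="test_dataset",
    dataset_schema={
        "expected_answer": {"type": "string"},
        "input_query": {"type": "string"},
        "generated_answer": {"type": "string"},
    },
    url=dict(uri=test_url),
    provider_id=dataset_provider_id,
)
page = client.datasetio.get_rows_paginated(dataset_id="test_dataset", rows_in_page=3)
# page.next_page_token ("3") feeds the next call.

# After (added below): purpose plus a typed source, read back with iterrows.
dataset = client.datasets.register(
    purpose="eval/messages-answer",
    source={"type": "uri", "uri": test_url},
)
rows = client.datasets.iterrows(dataset.identifier)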

View file

@@ -5,6 +5,10 @@
 # the root directory of this source tree.
+import base64
+import mimetypes
+import os
+
 import pytest
 
 # How to run this test:
@@ -12,6 +16,21 @@ import pytest
 # LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets
+
+
+def data_url_from_file(file_path: str) -> str:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return data_url
+
+
 @pytest.mark.parametrize(
     "purpose, source, provider_id, limit",
     [
@@ -67,3 +86,19 @@ def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit):
     llama_stack_client.datasets.unregister(dataset.identifier)
     dataset_list = llama_stack_client.datasets.list()
     assert dataset.identifier not in [d.identifier for d in dataset_list]
+
+
+def test_register_and_iterrows_from_base64_data(llama_stack_client):
+    dataset = llama_stack_client.datasets.register(
+        purpose="eval/messages-answer",
+        source={
+            "type": "uri",
+            "uri": data_url_from_file(
+                os.path.join(os.path.dirname(__file__), "test_dataset.csv")
+            ),
+        },
+    )
+    assert dataset.identifier is not None
+    assert dataset.provider_id == "localfs"
+
+    iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier)
+    assert len(iterrow_response.data) == 5
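The new test expects exactly 5 rows back from the localfs provider, which implies test_dataset.csv holds a header plus five records. Under that assumption, the data_url_from_file helper can be sanity-checked standalone; the temp file and its contents below are invented for illustration, and the snippet assumes the helper is in scope (for example, run inside the test module):

import base64
import io
import tempfile

import pandas

# Header plus five data rows, mirroring what the iterrows assertion implies.
csv_text = "input_query,expected_answer\n" + "\n".join(f"q{i},a{i}" for i in range(5))
with tempfile.NamedTemporaryFile(suffix=".csv", mode="w", delete=False) as f:
    f.write(csv_text)
    path = f.name

url = data_url_from_file(path)  # helper added in this commit
assert url.startswith("data:text/csv;base64,")

# Decode the payload and confirm pandas sees five records, as the test does.
payload = base64.b64decode(url.split(",", 1)[1])
df = pandas.read_csv(io.BytesIO(payload))
assert len(df) == 5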