better test

This commit is contained in:
Xi Yan 2025-03-15 16:45:01 -07:00
parent 8afadc6829
commit bb25509c37
3 changed files with 4 additions and 13 deletions

View file

@@ -44,9 +44,7 @@ class PandasDataframeDataset:
elif self.dataset_def.source.type == "rows": elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows) self.df = pandas.DataFrame(self.dataset_def.source.rows)
else: else:
raise ValueError( raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
f"Unsupported dataset source type: {self.dataset_def.source.type}"
)
if self.df is None: if self.df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}") raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
@@ -119,6 +117,4 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl.load() dataset_impl.load()
new_rows_df = pandas.DataFrame(rows) new_rows_df = pandas.DataFrame(rows)
dataset_impl.df = pandas.concat( dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
[dataset_impl.df, new_rows_df], ignore_index=True
)

View file

@@ -10,7 +10,6 @@ from urllib.parse import unquote
import pandas import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.providers.utils.memory.vector_store import parse_data_url from llama_stack.providers.utils.memory.vector_store import parse_data_url

View file

@@ -70,9 +70,7 @@ def data_url_from_file(file_path: str) -> str:
"eval/messages-answer", "eval/messages-answer",
{ {
"type": "uri", "type": "uri",
"uri": data_url_from_file( "uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
os.path.join(os.path.dirname(__file__), "test_dataset.csv")
),
}, },
"localfs", "localfs",
5, 5,
@@ -86,9 +84,7 @@ def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id,
) )
assert dataset.identifier is not None assert dataset.identifier is not None
assert dataset.provider_id == provider_id assert dataset.provider_id == provider_id
iterrow_response = llama_stack_client.datasets.iterrows( iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit)
dataset.identifier, limit=limit
)
assert len(iterrow_response.data) == limit assert len(iterrow_response.data) == limit
dataset_list = llama_stack_client.datasets.list() dataset_list = llama_stack_client.datasets.list()