more tests

This commit is contained in:
Xi Yan 2025-03-15 16:42:47 -07:00
parent 4fee3af91f
commit 084bacc029
7 changed files with 51 additions and 126 deletions

View file

@ -10,7 +10,7 @@ import pandas
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import LocalFSDatasetIOConfig
@ -40,11 +40,13 @@ class PandasDataframeDataset:
return
if self.dataset_def.source.type == "uri":
self.df = get_dataframe_from_url(self.dataset_def.uri)
self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
raise ValueError(
f"Unsupported dataset source type: {self.dataset_def.source.type}"
)
if self.df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
@ -117,4 +119,6 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
dataset_impl.df = pandas.concat(
[dataset_impl.df, new_rows_df], ignore_index=True
)

View file

@ -14,14 +14,14 @@ from llama_stack.apis.common.content_types import URL
from llama_stack.providers.utils.memory.vector_store import parse_data_url
def get_dataframe_from_url(url: URL):
def get_dataframe_from_uri(uri: str):
df = None
if url.uri.endswith(".csv"):
df = pandas.read_csv(url.uri)
elif url.uri.endswith(".xlsx"):
df = pandas.read_excel(url.uri)
elif url.uri.startswith("data:"):
parts = parse_data_url(url.uri)
if uri.endswith(".csv"):
df = pandas.read_csv(uri)
elif uri.endswith(".xlsx"):
df = pandas.read_excel(uri)
elif uri.startswith("data:"):
parts = parse_data_url(uri)
data = parts["data"]
if parts["is_base64"]:
data = base64.b64decode(data)
@ -39,6 +39,6 @@ def get_dataframe_from_url(url: URL):
else:
df = pandas.read_excel(data_bytes)
else:
raise ValueError(f"Unsupported file type: {url}")
raise ValueError(f"Unsupported file type: {uri}")
return df