mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 10:00:00 +00:00
more tests
This commit is contained in:
parent
4fee3af91f
commit
084bacc029
7 changed files with 51 additions and 126 deletions
|
|
@ -10,7 +10,7 @@ import pandas
|
|||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
|
@ -40,11 +40,13 @@ class PandasDataframeDataset:
|
|||
return
|
||||
|
||||
if self.dataset_def.source.type == "uri":
|
||||
self.df = get_dataframe_from_url(self.dataset_def.uri)
|
||||
self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
|
||||
elif self.dataset_def.source.type == "rows":
|
||||
self.df = pandas.DataFrame(self.dataset_def.source.rows)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
|
||||
raise ValueError(
|
||||
f"Unsupported dataset source type: {self.dataset_def.source.type}"
|
||||
)
|
||||
|
||||
if self.df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
|
|
@ -117,4 +119,6 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
dataset_impl.load()
|
||||
|
||||
new_rows_df = pandas.DataFrame(rows)
|
||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||
dataset_impl.df = pandas.concat(
|
||||
[dataset_impl.df, new_rows_df], ignore_index=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,14 +14,14 @@ from llama_stack.apis.common.content_types import URL
|
|||
from llama_stack.providers.utils.memory.vector_store import parse_data_url
|
||||
|
||||
|
||||
def get_dataframe_from_url(url: URL):
|
||||
def get_dataframe_from_uri(uri: str):
|
||||
df = None
|
||||
if url.uri.endswith(".csv"):
|
||||
df = pandas.read_csv(url.uri)
|
||||
elif url.uri.endswith(".xlsx"):
|
||||
df = pandas.read_excel(url.uri)
|
||||
elif url.uri.startswith("data:"):
|
||||
parts = parse_data_url(url.uri)
|
||||
if uri.endswith(".csv"):
|
||||
df = pandas.read_csv(uri)
|
||||
elif uri.endswith(".xlsx"):
|
||||
df = pandas.read_excel(uri)
|
||||
elif uri.startswith("data:"):
|
||||
parts = parse_data_url(uri)
|
||||
data = parts["data"]
|
||||
if parts["is_base64"]:
|
||||
data = base64.b64decode(data)
|
||||
|
|
@ -39,6 +39,6 @@ def get_dataframe_from_url(url: URL):
|
|||
else:
|
||||
df = pandas.read_excel(data_bytes)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {url}")
|
||||
raise ValueError(f"Unsupported file type: {uri}")
|
||||
|
||||
return df
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue