Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
refactor

commit 6525b43906
parent edeb6dcf04

1 changed file with 4 additions and 29 deletions
@@ -3,20 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import io
 from typing import List, Optional

 import pandas
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from llama_stack.apis.datasetio import *  # noqa: F403
-import base64
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from urllib.parse import unquote

 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import parse_data_url
+from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url

 from .config import MetaReferenceDatasetIOConfig

@@ -73,31 +70,9 @@ class PandasDataframeDataset(BaseDataset):
         if self.df is not None:
             return

-        # TODO: more robust support w/ data url
-        if self.dataset_def.url.uri.endswith(".csv"):
-            df = pandas.read_csv(self.dataset_def.url.uri)
-        elif self.dataset_def.url.uri.endswith(".xlsx"):
-            df = pandas.read_excel(self.dataset_def.url.uri)
-        elif self.dataset_def.url.uri.startswith("data:"):
-            parts = parse_data_url(self.dataset_def.url.uri)
-            data = parts["data"]
-            if parts["is_base64"]:
-                data = base64.b64decode(data)
-            else:
-                data = unquote(data)
-                encoding = parts["encoding"] or "utf-8"
-                data = data.encode(encoding)
-
-            mime_type = parts["mimetype"]
-            mime_category = mime_type.split("/")[0]
-            data_bytes = io.BytesIO(data)
-
-            if mime_category == "text":
-                df = pandas.read_csv(data_bytes)
-            else:
-                df = pandas.read_excel(data_bytes)
-        else:
-            raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
+        df = get_dataframe_from_url(self.dataset_def.url)
+        if df is None:
+            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")

         self.df = self._validate_dataset_schema(df)

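The branching that load() used to do by hand (plain .csv paths, .xlsx workbooks, and data: URLs) is now delegated to get_dataframe_from_url from llama_stack.providers.utils.datasetio.url_utils. That helper is not shown in this diff; the sketch below is a hypothetical reconstruction that simply consolidates the removed branches, so its signature and details are assumptions rather than the actual url_utils code.

# Hypothetical reconstruction of get_dataframe_from_url, assembled from the
# branches deleted above; the real url_utils implementation may differ.
import base64
import io
from typing import Optional
from urllib.parse import unquote

import pandas

from llama_stack.providers.utils.memory.vector_store import parse_data_url


def get_dataframe_from_url(url) -> Optional[pandas.DataFrame]:
    # `url` is assumed to expose a `.uri` string, as dataset_def.url does above.
    uri = url.uri
    if uri.endswith(".csv"):
        return pandas.read_csv(uri)
    if uri.endswith(".xlsx"):
        return pandas.read_excel(uri)
    if uri.startswith("data:"):
        parts = parse_data_url(uri)
        data = parts["data"]
        if parts["is_base64"]:
            data = base64.b64decode(data)
        else:
            data = unquote(data)
            encoding = parts["encoding"] or "utf-8"
            data = data.encode(encoding)

        # text/* payloads are parsed as CSV, anything else as an Excel workbook,
        # mirroring the branch removed from load().
        mime_category = parts["mimetype"].split("/")[0]
        data_bytes = io.BytesIO(data)
        if mime_category == "text":
            return pandas.read_csv(data_bytes)
        return pandas.read_excel(data_bytes)
    # Unsupported scheme or extension: let the caller decide how to report it.
    return None

Having the helper return None for unrecognized URLs, rather than raising, matches the new call site in load(), which raises its own ValueError when no dataframe comes back.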