From 6525b439064a3dbe79e011e1167de8ae31cdb1f9 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 7 Nov 2024 18:41:33 -0800 Subject: [PATCH] refactor --- .../meta_reference/datasetio/datasetio.py | 33 +++---------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/llama_stack/providers/inline/meta_reference/datasetio/datasetio.py b/llama_stack/providers/inline/meta_reference/datasetio/datasetio.py index a96d9bcab..a6fe4feb3 100644 --- a/llama_stack/providers/inline/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/inline/meta_reference/datasetio/datasetio.py @@ -3,20 +3,17 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import io from typing import List, Optional import pandas from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 -import base64 from abc import ABC, abstractmethod from dataclasses import dataclass -from urllib.parse import unquote from llama_stack.providers.datatypes import DatasetsProtocolPrivate -from llama_stack.providers.utils.memory.vector_store import parse_data_url +from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from .config import MetaReferenceDatasetIOConfig @@ -73,31 +70,9 @@ class PandasDataframeDataset(BaseDataset): if self.df is not None: return - # TODO: more robust support w/ data url - if self.dataset_def.url.uri.endswith(".csv"): - df = pandas.read_csv(self.dataset_def.url.uri) - elif self.dataset_def.url.uri.endswith(".xlsx"): - df = pandas.read_excel(self.dataset_def.url.uri) - elif self.dataset_def.url.uri.startswith("data:"): - parts = parse_data_url(self.dataset_def.url.uri) - data = parts["data"] - if parts["is_base64"]: - data = base64.b64decode(data) - else: - data = unquote(data) - encoding = parts["encoding"] or "utf-8" - data = data.encode(encoding) - - mime_type = parts["mimetype"] - mime_category = mime_type.split("/")[0] - data_bytes = io.BytesIO(data) - - if mime_category == "text": - df = pandas.read_csv(data_bytes) - else: - df = pandas.read_excel(data_bytes) - else: - raise ValueError(f"Unsupported file type: {self.dataset_def.url}") + df = get_dataframe_from_url(self.dataset_def.url) + if df is None: + raise ValueError(f"Failed to load dataset from {self.dataset_def.url}") self.df = self._validate_dataset_schema(df)