From 37d87c585a7de98451af848458553a9b36437fea Mon Sep 17 00:00:00 2001
From: Xi Yan <xiyan@meta.com>
Date: Thu, 7 Nov 2024 15:59:55 -0800
Subject: [PATCH] wip huggingface register

---
 .../huggingface/datasetio/huggingface.py      | 20 ++++++++-
 llama_stack/providers/tests/eval/test_eval.py |  6 +--
 .../providers/utils/datasetio/__init__.py     |  5 +++
 .../providers/utils/datasetio/url_utils.py    | 26 ++++++++++++
 .../providers/utils/memory/file_utils.py      | 41 ++++++++++++++-----
 5 files changed, 82 insertions(+), 16 deletions(-)
 create mode 100644 llama_stack/providers/utils/datasetio/__init__.py
 create mode 100644 llama_stack/providers/utils/datasetio/url_utils.py

diff --git a/llama_stack/providers/inline/huggingface/datasetio/huggingface.py b/llama_stack/providers/inline/huggingface/datasetio/huggingface.py
index 3b8c3049c..1f9ef7cfc 100644
--- a/llama_stack/providers/inline/huggingface/datasetio/huggingface.py
+++ b/llama_stack/providers/inline/huggingface/datasetio/huggingface.py
@@ -7,13 +7,27 @@ from typing import List, Optional
 
 from llama_stack.apis.datasetio import *  # noqa: F403
 
-from datasets import load_dataset
+
+from datasets import Dataset, load_dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 
 from .config import HuggingfaceDatasetIOConfig
 from .dataset_defs.llamastack_mmlu import llamastack_mmlu
 
 
+def load_hf_dataset(dataset_def: DatasetDef):
+    if dataset_def.metadata.get("dataset_path", None):
+        return load_dataset(**dataset_def.metadata)
+
+    df = get_dataframe_from_url(dataset_def.url)
+
+    if df is None:
+        raise ValueError(f"Failed to load dataset from {dataset_def.url}")
+
+    dataset = Dataset.from_pandas(df)
+    return dataset
+
+
 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
     def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
         self.config = config
@@ -31,6 +45,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self,
         dataset_def: DatasetDef,
     ) -> None:
+
+        print("registering dataset", dataset_def)
         self.dataset_infos[dataset_def.identifier] = dataset_def
 
     async def list_datasets(self) -> List[DatasetDef]:
@@ -44,7 +60,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_dataset(**dataset_def.metadata)
+        loaded_dataset = load_hf_dataset(dataset_def)
 
         if page_token and not page_token.isnumeric():
             raise ValueError("Invalid page_token")
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index fe09832bb..cb611870b 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -44,11 +44,11 @@ class Testeval:
         provider = datasetio_impl.routing_table.get_provider_impl(
             "test_dataset_for_eval"
         )
-        if provider.__provider_spec__.provider_type != "meta-reference":
-            pytest.skip("Only meta-reference provider supports registering datasets")
+        # if provider.__provider_spec__.provider_type != "meta-reference":
+        #     pytest.skip("Only meta-reference provider supports registering datasets")
 
         response = await datasets_impl.list_datasets()
-        assert len(response) == 1
+        assert len(response) >= 1
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset_for_eval",
             rows_in_page=3,
diff --git a/llama_stack/providers/utils/datasetio/__init__.py b/llama_stack/providers/utils/datasetio/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/utils/datasetio/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/utils/datasetio/url_utils.py b/llama_stack/providers/utils/datasetio/url_utils.py
new file mode 100644
index 000000000..bc4462fa0
--- /dev/null
+++ b/llama_stack/providers/utils/datasetio/url_utils.py
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import mimetypes
+import os
+
+from llama_models.llama3.api.datatypes import URL
+
+
+def data_url_from_file(file_path: str) -> URL:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return URL(uri=data_url)
diff --git a/llama_stack/providers/utils/memory/file_utils.py b/llama_stack/providers/utils/memory/file_utils.py
index bc4462fa0..3faea9f95 100644
--- a/llama_stack/providers/utils/memory/file_utils.py
+++ b/llama_stack/providers/utils/memory/file_utils.py
@@ -5,22 +5,41 @@
 # the root directory of this source tree.
 
 import base64
-import mimetypes
-import os
+import io
+from urllib.parse import unquote
+
+import pandas
 
 from llama_models.llama3.api.datatypes import URL
+from llama_stack.providers.utils.memory.vector_store import parse_data_url
 
 
-def data_url_from_file(file_path: str) -> URL:
-    if not os.path.exists(file_path):
-        raise FileNotFoundError(f"File not found: {file_path}")
 
-    with open(file_path, "rb") as file:
-        file_content = file.read()
+def get_dataframe_from_url(url: URL):
+    df = None
+    if url.uri.endswith(".csv"):
+        df = pandas.read_csv(url.uri)
+    elif url.uri.endswith(".xlsx"):
+        df = pandas.read_excel(url.uri)
+    elif url.uri.startswith("data:"):
+        parts = parse_data_url(url.uri)
+        data = parts["data"]
+        if parts["is_base64"]:
+            data = base64.b64decode(data)
+        else:
+            data = unquote(data)
+            encoding = parts["encoding"] or "utf-8"
+            data = data.encode(encoding)
 
-    base64_content = base64.b64encode(file_content).decode("utf-8")
-    mime_type, _ = mimetypes.guess_type(file_path)
+        mime_type = parts["mimetype"]
+        mime_category = mime_type.split("/")[0]
+        data_bytes = io.BytesIO(data)
 
-    data_url = f"data:{mime_type};base64,{base64_content}"
+        if mime_category == "text":
+            df = pandas.read_csv(data_bytes)
+        else:
+            df = pandas.read_excel(data_bytes)
+    else:
+        raise ValueError(f"Unsupported file type: {url}")
 
-    return URL(uri=data_url)
+    return df