wip huggingface register

Xi Yan 2024-11-07 15:59:55 -08:00
parent d1633dc412
commit 37d87c585a
5 changed files with 82 additions and 16 deletions


@@ -7,13 +7,27 @@ from typing import List, Optional
 from llama_stack.apis.datasetio import *  # noqa: F403
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate

 from .config import HuggingfaceDatasetIOConfig
 from .dataset_defs.llamastack_mmlu import llamastack_mmlu


+def load_hf_dataset(dataset_def: DatasetDef):
+    if dataset_def.metadata.get("dataset_path", None):
+        return load_dataset(**dataset_def.metadata)
+
+    df = get_dataframe_from_url(dataset_def.url)
+    if df is None:
+        raise ValueError(f"Failed to load dataset from {dataset_def.url}")
+
+    dataset = Dataset.from_pandas(df)
+    return dataset
+
+
 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
     def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
         self.config = config
@@ -31,6 +45,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self,
         dataset_def: DatasetDef,
     ) -> None:
+        print("registering dataset", dataset_def)
+
         self.dataset_infos[dataset_def.identifier] = dataset_def

     async def list_datasets(self) -> List[DatasetDef]:
@@ -44,7 +60,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_dataset(**dataset_def.metadata)
+        loaded_dataset = load_hf_dataset(dataset_def)

         if page_token and not page_token.isnumeric():
             raise ValueError("Invalid page_token")
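
For illustration (not part of the commit): the fallback branch of load_hf_dataset wraps URL-backed rows into a Hugging Face Dataset via pandas. A minimal standalone sketch of that wrapping step, with made-up column names:

import pandas as pd
from datasets import Dataset

# Stand-in for the DataFrame that get_dataframe_from_url would return;
# the column names here are illustrative, not from the commit.
df = pd.DataFrame(
    {
        "input_query": ["What is 2+2?", "Capital of France?"],
        "expected_answer": ["4", "Paris"],
    }
)

dataset = Dataset.from_pandas(df)
print(dataset[0])  # {'input_query': 'What is 2+2?', 'expected_answer': '4'}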


@@ -44,11 +44,11 @@ class Testeval:
         provider = datasetio_impl.routing_table.get_provider_impl(
             "test_dataset_for_eval"
         )
-        if provider.__provider_spec__.provider_type != "meta-reference":
-            pytest.skip("Only meta-reference provider supports registering datasets")
+        # if provider.__provider_spec__.provider_type != "meta-reference":
+        #     pytest.skip("Only meta-reference provider supports registering datasets")

         response = await datasets_impl.list_datasets()
-        assert len(response) == 1
+        assert len(response) >= 1

         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset_for_eval",
             rows_in_page=3,
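
For reference, a hedged sketch (not part of the commit) of driving get_rows_paginated to exhaustion; the rows and next_page_token field names are assumed from PaginatedRowsResult, and the numeric-token convention matches the check in the provider above:

async def collect_all_rows(datasetio_impl, dataset_id: str, page_size: int = 3):
    # Walk every page of a registered dataset and gather the rows.
    rows, token = [], None
    while True:
        page = await datasetio_impl.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=page_size,
            page_token=token,
        )
        rows.extend(page.rows)
        token = page.next_page_token  # assumed field name
        if not token:  # an empty token signals the last page
            break
    return rows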


@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.


@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import base64
+import mimetypes
+import os
+
+from llama_models.llama3.api.datatypes import URL
+
+
+def data_url_from_file(file_path: str) -> URL:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return URL(uri=data_url)
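
A quick way to exercise the new helper (not part of the commit), assuming data_url_from_file above is importable: write a throwaway CSV and inspect the resulting data URL.

import tempfile

# Create a small CSV on disk; the contents are arbitrary sample data.
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
    f.write("input_query,expected_answer\nWhat is 2+2?,4\n")
    path = f.name

url = data_url_from_file(path)
# mimetypes maps ".csv" to "text/csv", so the URI looks like:
#   data:text/csv;base64,aW5wdXRfcXVl...
print(url.uri.split(",")[0])  # data:text/csv;base64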


@@ -5,22 +5,41 @@
 # the root directory of this source tree.
 import base64
-import mimetypes
-import os
+import io
+from urllib.parse import unquote
+
+import pandas

 from llama_models.llama3.api.datatypes import URL

+from llama_stack.providers.utils.memory.vector_store import parse_data_url

-def data_url_from_file(file_path: str) -> URL:
-    if not os.path.exists(file_path):
-        raise FileNotFoundError(f"File not found: {file_path}")
-
-    with open(file_path, "rb") as file:
-        file_content = file.read()
-
-    base64_content = base64.b64encode(file_content).decode("utf-8")
-    mime_type, _ = mimetypes.guess_type(file_path)
-
-    data_url = f"data:{mime_type};base64,{base64_content}"
-
-    return URL(uri=data_url)
+
+def get_dataframe_from_url(url: URL):
+    df = None
+    if url.uri.endswith(".csv"):
+        df = pandas.read_csv(url.uri)
+    elif url.uri.endswith(".xlsx"):
+        df = pandas.read_excel(url.uri)
+    elif url.uri.startswith("data:"):
+        parts = parse_data_url(url.uri)
+        data = parts["data"]
+
+        if parts["is_base64"]:
+            data = base64.b64decode(data)
+        else:
+            data = unquote(data)
+            encoding = parts["encoding"] or "utf-8"
+            data = data.encode(encoding)
+
+        mime_type = parts["mimetype"]
+        mime_category = mime_type.split("/")[0]
+        data_bytes = io.BytesIO(data)
+
+        if mime_category == "text":
+            df = pandas.read_csv(data_bytes)
+        else:
+            df = pandas.read_excel(data_bytes)
+    else:
+        raise ValueError(f"Unsupported file type: {url}")
+
+    return df
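
To close the loop (not part of the commit): a data URL like the one data_url_from_file produces can be fed straight back through get_dataframe_from_url. A hedged sketch, assuming both helpers are importable from their new homes:

import base64

from llama_models.llama3.api.datatypes import URL

csv_bytes = b"input_query,expected_answer\nWhat is 2+2?,4\n"
url = URL(uri="data:text/csv;base64," + base64.b64encode(csv_bytes).decode("utf-8"))

# Hits the "data:" branch above: parse_data_url extracts the base64 payload,
# and the "text" MIME category routes the bytes to pandas.read_csv.
df = get_dataframe_from_url(url)
print(df.columns.tolist())  # ['input_query', 'expected_answer']

# Plain file URLs take the extension-sniffing branches instead, e.g.:
# get_dataframe_from_url(URL(uri="https://example.com/eval.csv"))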