mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 00:34:44 +00:00
wip huggingface register
This commit is contained in:
parent
d1633dc412
commit
37d87c585a
5 changed files with 82 additions and 16 deletions
|
@ -7,13 +7,27 @@ from typing import List, Optional
|
|||
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
|
||||
from .config import HuggingfaceDatasetIOConfig
|
||||
from .dataset_defs.llamastack_mmlu import llamastack_mmlu
|
||||
|
||||
|
||||
def load_hf_dataset(dataset_def: DatasetDef):
    """Materialize a dataset definition into a HuggingFace ``Dataset``.

    Two sources are supported:
      * definitions whose metadata carries ``dataset_path`` are forwarded
        verbatim as keyword arguments to ``datasets.load_dataset``;
      * otherwise the rows referenced by ``dataset_def.url`` are fetched
        as a dataframe and wrapped via ``Dataset.from_pandas``.

    Raises:
        ValueError: If the URL-based path yields no dataframe.
    """
    # Hub-style definition: metadata is a ready-made load_dataset kwargs dict.
    if dataset_def.metadata.get("dataset_path"):
        return load_dataset(**dataset_def.metadata)

    # URL-style definition (csv / xlsx / data: URI).
    df = get_dataframe_from_url(dataset_def.url)
    if df is None:
        raise ValueError(f"Failed to load dataset from {dataset_def.url}")

    return Dataset.from_pandas(df)
|
||||
|
||||
|
||||
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
|
||||
self.config = config
|
||||
|
@ -31,6 +45,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
self,
|
||||
dataset_def: DatasetDef,
|
||||
) -> None:
|
||||
|
||||
print("registering dataset", dataset_def)
|
||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||
|
||||
async def list_datasets(self) -> List[DatasetDef]:
|
||||
|
@ -44,7 +60,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
loaded_dataset = load_dataset(**dataset_def.metadata)
|
||||
loaded_dataset = load_hf_dataset(dataset_def)
|
||||
|
||||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
|
|
|
@ -44,11 +44,11 @@ class Testeval:
|
|||
provider = datasetio_impl.routing_table.get_provider_impl(
|
||||
"test_dataset_for_eval"
|
||||
)
|
||||
if provider.__provider_spec__.provider_type != "meta-reference":
|
||||
pytest.skip("Only meta-reference provider supports registering datasets")
|
||||
# if provider.__provider_spec__.provider_type != "meta-reference":
|
||||
# pytest.skip("Only meta-reference provider supports registering datasets")
|
||||
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) == 1
|
||||
assert len(response) >= 1
|
||||
rows = await datasetio_impl.get_rows_paginated(
|
||||
dataset_id="test_dataset_for_eval",
|
||||
rows_in_page=3,
|
||||
|
|
5
llama_stack/providers/utils/datasetio/__init__.py
Normal file
5
llama_stack/providers/utils/datasetio/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
26
llama_stack/providers/utils/datasetio/url_utils.py
Normal file
26
llama_stack/providers/utils/datasetio/url_utils.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
|
||||
from llama_models.llama3.api.datatypes import URL
|
||||
|
||||
|
||||
def data_url_from_file(file_path: str) -> URL:
    """Read a local file and return its contents as a base64 ``data:`` URL.

    Args:
        file_path: Path to an existing file on disk.

    Returns:
        URL whose ``uri`` is a ``data:<mime>;base64,<payload>`` string.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)
    # guess_type returns None for unrecognized extensions; without this
    # fallback the result would literally read "data:None;base64,...".
    if mime_type is None:
        mime_type = "application/octet-stream"

    data_url = f"data:{mime_type};base64,{base64_content}"

    return URL(uri=data_url)
|
|
@ -5,22 +5,41 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
import io
|
||||
from urllib.parse import unquote
|
||||
|
||||
import pandas
|
||||
|
||||
from llama_models.llama3.api.datatypes import URL
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import parse_data_url
|
||||
|
||||
def data_url_from_file(file_path: str) -> URL:
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, "rb") as file:
|
||||
file_content = file.read()
|
||||
def get_dataframe_from_url(url: "URL"):
    """Materialize the content referenced by ``url`` into a pandas DataFrame.

    Supported URI forms:
      * paths/URLs ending in ``.csv``  -> ``pandas.read_csv``
      * paths/URLs ending in ``.xlsx`` -> ``pandas.read_excel``
      * RFC 2397 ``data:`` URLs        -> decoded, then parsed by mime category

    Args:
        url: Object with a ``uri`` string attribute.

    Returns:
        The parsed ``pandas.DataFrame``.

    Raises:
        ValueError: For any other URI form.
    """
    uri = url.uri
    # Match extensions case-insensitively so e.g. "DATA.CSV" is accepted too.
    lowered = uri.lower()
    if lowered.endswith(".csv"):
        return pandas.read_csv(uri)
    if lowered.endswith(".xlsx"):
        return pandas.read_excel(uri)
    if uri.startswith("data:"):
        parts = parse_data_url(uri)
        data = parts["data"]
        if parts["is_base64"]:
            data = base64.b64decode(data)
        else:
            # Percent-decoded text must be re-encoded to bytes for BytesIO.
            data = unquote(data)
            encoding = parts["encoding"] or "utf-8"
            data = data.encode(encoding)

        # NOTE(review): assumes parse_data_url always supplies a mimetype;
        # a data: URL omitting it would make .split raise here — confirm helper.
        mime_type = parts["mimetype"]
        mime_category = mime_type.split("/")[0]
        data_bytes = io.BytesIO(data)

        # Text payloads are treated as CSV; anything else as a spreadsheet.
        if mime_category == "text":
            return pandas.read_csv(data_bytes)
        return pandas.read_excel(data_bytes)

    raise ValueError(f"Unsupported file type: {url}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue