Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
Commit 37d87c585a (parent d1633dc412)

wip huggingface register

5 changed files with 82 additions and 16 deletions
@@ -7,13 +7,27 @@ from typing import List, Optional

 from llama_stack.apis.datasetio import *  # noqa: F403

-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate

 from .config import HuggingfaceDatasetIOConfig
 from .dataset_defs.llamastack_mmlu import llamastack_mmlu


+def load_hf_dataset(dataset_def: DatasetDef):
+    if dataset_def.metadata.get("dataset_path", None):
+        return load_dataset(**dataset_def.metadata)
+
+    df = get_dataframe_from_url(dataset_def.url)
+
+    if df is None:
+        raise ValueError(f"Failed to load dataset from {dataset_def.url}")
+
+    dataset = Dataset.from_pandas(df)
+    return dataset
+
+
 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
     def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
         self.config = config
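
The new load_hf_dataset helper takes two paths: if the dataset's metadata names a Hugging Face dataset, the metadata is forwarded verbatim to datasets.load_dataset; otherwise the dataset's URL is fetched into a pandas DataFrame via get_dataframe_from_url (defined in the url_utils module added later in this commit; the import line is not visible in this hunk) and wrapped in a datasets.Dataset. A minimal standalone sketch of the fallback path, using throwaway data rather than a registered DatasetDef:

# Fallback path in isolation: wrap a pandas DataFrame in a datasets.Dataset,
# as load_hf_dataset does when no hub path is present in metadata.
# Requires `pip install datasets pandas`.
import pandas as pd
from datasets import Dataset

df = pd.DataFrame({"input_query": ["What is 2+2?"], "expected_answer": ["4"]})
ds = Dataset.from_pandas(df)
print(ds[0])  # {'input_query': 'What is 2+2?', 'expected_answer': '4'}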
@@ -31,6 +45,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self,
         dataset_def: DatasetDef,
     ) -> None:
+        print("registering dataset", dataset_def)
         self.dataset_infos[dataset_def.identifier] = dataset_def

     async def list_datasets(self) -> List[DatasetDef]:
@@ -44,7 +60,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_dataset(**dataset_def.metadata)
+        loaded_dataset = load_hf_dataset(dataset_def)

         if page_token and not page_token.isnumeric():
             raise ValueError("Invalid page_token")
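
get_rows_paginated now routes through the new helper, so URL-backed datasets paginate the same way hub-hosted ones do. The isnumeric() check implies page_token is an integer offset encoded as a string; a rough sketch of that style of offset pagination (the helper below is illustrative, not code from this diff):

# Hypothetical offset pagination matching the page_token/rows_in_page
# parameters above; plain Python, not llama-stack code.
def paginate_rows(rows, page_token, rows_in_page):
    start = int(page_token) if page_token else 0
    end = start + rows_in_page if rows_in_page > 0 else len(rows)
    next_token = str(end) if end < len(rows) else None
    return rows[start:end], next_token

page, token = paginate_rows(list(range(10)), None, 3)  # ([0, 1, 2], '3')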
@@ -44,11 +44,11 @@ class Testeval:
         provider = datasetio_impl.routing_table.get_provider_impl(
             "test_dataset_for_eval"
         )
-        if provider.__provider_spec__.provider_type != "meta-reference":
-            pytest.skip("Only meta-reference provider supports registering datasets")
+        # if provider.__provider_spec__.provider_type != "meta-reference":
+        #     pytest.skip("Only meta-reference provider supports registering datasets")

         response = await datasets_impl.list_datasets()
-        assert len(response) == 1
+        assert len(response) >= 1
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset_for_eval",
             rows_in_page=3,
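
The test changes track the WIP provider work: the meta-reference-only skip is commented out so the huggingface provider is exercised too, and the count assertion is relaxed to >= 1, which makes room for providers that pre-register built-in dataset defs (such as the llamastack_mmlu def imported above) alongside the test dataset. A sketch of the kind of CSV-carrying payload such a test can register (stdlib only; the column names here are assumptions):

# Build a data URL carrying a tiny CSV, suitable as the url of a test
# DatasetDef; register_dataset usage is assumed, not shown in this diff.
import base64

csv_content = "input_query,expected_answer\nWhat is 2+2?,4\n"
b64 = base64.b64encode(csv_content.encode("utf-8")).decode("utf-8")
data_url = f"data:text/csv;base64,{b64}"
print(data_url[:40], "...")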
llama_stack/providers/utils/datasetio/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
llama_stack/providers/utils/datasetio/url_utils.py (new file, 26 lines)

@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import mimetypes
+import os
+
+from llama_models.llama3.api.datatypes import URL
+
+
+def data_url_from_file(file_path: str) -> URL:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return URL(uri=data_url)
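
data_url_from_file reads a file, base64-encodes it, and guesses the MIME type from the filename. Note that mimetypes.guess_type returns None for unknown suffixes, which would produce a malformed "data:None;base64,..." URL; this WIP version does not guard against that. A quick stdlib-only check of the same encoding steps:

# Reproduce the encoding steps against a temp CSV file (stdlib only).
import base64
import mimetypes
import tempfile

with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
    f.write("a,b\n1,2\n")
    path = f.name

mime_type, _ = mimetypes.guess_type(path)  # 'text/csv' from the .csv suffix
with open(path, "rb") as fh:
    b64 = base64.b64encode(fh.read()).decode("utf-8")
print(f"data:{mime_type};base64,{b64}")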
@@ -5,22 +5,41 @@
 # the root directory of this source tree.

 import base64
-import mimetypes
-import os
+import io
+from urllib.parse import unquote
+
+import pandas

 from llama_models.llama3.api.datatypes import URL

+from llama_stack.providers.utils.memory.vector_store import parse_data_url
+

-def data_url_from_file(file_path: str) -> URL:
-    if not os.path.exists(file_path):
-        raise FileNotFoundError(f"File not found: {file_path}")
-
-    with open(file_path, "rb") as file:
-        file_content = file.read()
-
-    base64_content = base64.b64encode(file_content).decode("utf-8")
-    mime_type, _ = mimetypes.guess_type(file_path)
-
-    data_url = f"data:{mime_type};base64,{base64_content}"
-
-    return URL(uri=data_url)
+def get_dataframe_from_url(url: URL):
+    df = None
+    if url.uri.endswith(".csv"):
+        df = pandas.read_csv(url.uri)
+    elif url.uri.endswith(".xlsx"):
+        df = pandas.read_excel(url.uri)
+    elif url.uri.startswith("data:"):
+        parts = parse_data_url(url.uri)
+        data = parts["data"]
+        if parts["is_base64"]:
+            data = base64.b64decode(data)
+        else:
+            data = unquote(data)
+            encoding = parts["encoding"] or "utf-8"
+            data = data.encode(encoding)
+
+        mime_type = parts["mimetype"]
+        mime_category = mime_type.split("/")[0]
+        data_bytes = io.BytesIO(data)
+
+        if mime_category == "text":
+            df = pandas.read_csv(data_bytes)
+        else:
+            df = pandas.read_excel(data_bytes)
+    else:
+        raise ValueError(f"Unsupported file type: {url}")
+
+    return df
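
get_dataframe_from_url dispatches on the URL: .csv and .xlsx paths go straight to pandas, while data: URLs are decoded (base64 or percent-encoded) and then sniffed by MIME category, with text/* read as CSV and anything else as Excel. A standalone round-trip sketch of the data: branch; the inline split is a crude stand-in for llama-stack's parse_data_url helper:

# Decode a base64 CSV data URL into a DataFrame, mirroring the data: branch.
# Requires `pip install pandas` (plus openpyxl for the Excel branch).
import base64
import io

import pandas

data_url = "data:text/csv;base64," + base64.b64encode(b"a,b\n1,2\n").decode()
_, b64_data = data_url.split(",", 1)  # stand-in for parse_data_url
df = pandas.read_csv(io.BytesIO(base64.b64decode(b64_data)))
print(df)  # columns a, b with one row: 1, 2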