Mirror of https://github.com/meta-llama/llama-stack.git, synced 2026-01-02 20:54:30 +00:00
fix endpoint, only sdk change

parent 13c7c5b6a1
commit 9e6d99f7b1

8 changed files with 161 additions and 72 deletions
@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 from typing import Any, Dict, List, Optional
+from urllib.parse import parse_qs, urlparse

 import datasets as hf_datasets

 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
@@ -16,24 +18,17 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
 from .config import HuggingfaceDatasetIOConfig

 DATASETS_PREFIX = "datasets:"
+from rich.pretty import pprint


-def load_hf_dataset(dataset_def: Dataset):
-    if dataset_def.metadata.get("path", None):
-        dataset = hf_datasets.load_dataset(**dataset_def.metadata)
-    else:
-        df = get_dataframe_from_url(dataset_def.url)
-
-        if df is None:
-            raise ValueError(f"Failed to load dataset from {dataset_def.url}")
-
-        dataset = hf_datasets.Dataset.from_pandas(df)
-
-    # drop columns not specified by schema
-    if dataset_def.dataset_schema:
-        dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
-
-    return dataset
+def parse_hf_params(dataset_def: Dataset):
+    uri = dataset_def.source.uri
+    parsed_uri = urlparse(uri)
+    params = parse_qs(parsed_uri.query)
+    params = {k: v[0] for k, v in params.items()}
+    path = parsed_uri.path.lstrip("/")
+
+    return path, params


 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
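The new parse_hf_params helper replaces load_hf_dataset: instead of loading the data itself, it splits the dataset's source URI into a Hugging Face repo path plus keyword arguments for hf_datasets.load_dataset. A minimal standalone sketch of that parsing, using a plain string in place of a llama_stack Dataset object; the URI scheme and repo name below are illustrative assumptions, not taken from the commit:

# Sketch of the URI parsing done by parse_hf_params (example URI is hypothetical).
from urllib.parse import parse_qs, urlparse

uri = "huggingface://datasets/tatsu-lab/alpaca?split=train"
parsed_uri = urlparse(uri)

# parse_qs yields {"split": ["train"]}; keep only the first value per key
params = {k: v[0] for k, v in parse_qs(parsed_uri.query).items()}
path = parsed_uri.path.lstrip("/")

print(path)    # tatsu-lab/alpaca
print(params)  # {'split': 'train'}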
@@ -60,6 +55,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self,
         dataset_def: Dataset,
     ) -> None:
+        print("register_dataset")
         # Store in kvstore
         key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
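For context, the key written to the kvstore here is just the dataset identifier under the "datasets:" prefix shown earlier. A tiny sketch with a hypothetical identifier:

# Key format used when registering a dataset; the identifier is made up for illustration.
DATASETS_PREFIX = "datasets:"
identifier = "my-eval-dataset"
key = f"{DATASETS_PREFIX}{identifier}"
print(key)  # datasets:my-eval-dataset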
@@ -80,7 +76,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)

         start_index = start_index or 0

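With this change, iterrows loads the dataset directly from the Hub using the parsed path and params, then pages through it starting at start_index. A rough sketch of that flow under assumed values; the repo name, page size, and the slicing shown are illustrative, since the hunk only shows the load call and the start_index default:

# Sketch of loading a dataset with the parsed values and reading one page of rows.
import datasets as hf_datasets

path, params = "tatsu-lab/alpaca", {"split": "train"}  # as parse_hf_params might return
loaded_dataset = hf_datasets.load_dataset(path, **params)

start_index, limit = 0, 5
end = min(start_index + limit, len(loaded_dataset))
rows = [loaded_dataset[i] for i in range(start_index, end)]
print(f"{len(rows)} rows starting at index {start_index}")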
@@ -98,15 +95,20 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)

         # Convert rows to HF Dataset format
         new_dataset = hf_datasets.Dataset.from_list(rows)

         # Concatenate the new rows with existing dataset
-        updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
+        updated_dataset = hf_datasets.concatenate_datasets(
+            [loaded_dataset, new_dataset]
+        )

         if dataset_def.metadata.get("path", None):
             updated_dataset.push_to_hub(dataset_def.metadata["path"])
         else:
-            raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
+            raise NotImplementedError(
+                "Uploading to URL-based datasets is not supported yet"
+            )
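append_rows follows the same pattern: load the existing dataset via the parsed path/params, wrap the incoming rows in a new Dataset, and concatenate the two. A small self-contained sketch of just the concatenation step; the column names are made up, and push_to_hub is omitted since it needs Hub credentials:

# Sketch of appending rows by building a Dataset from new rows and concatenating.
import datasets as hf_datasets

loaded_dataset = hf_datasets.Dataset.from_list([{"question": "1+1?", "answer": "2"}])
new_rows = [{"question": "2+2?", "answer": "4"}]

new_dataset = hf_datasets.Dataset.from_list(new_rows)
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
print(len(updated_dataset))  # 2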