fix endpoint, only sdk change

This commit is contained in:
Xi Yan 2025-03-15 16:15:45 -07:00
parent 13c7c5b6a1
commit 9e6d99f7b1
8 changed files with 161 additions and 72 deletions

View file

@ -5,6 +5,8 @@
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
import datasets as hf_datasets
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
@ -16,24 +18,17 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import HuggingfaceDatasetIOConfig
DATASETS_PREFIX = "datasets:"
from rich.pretty import pprint
def load_hf_dataset(dataset_def: Dataset):
if dataset_def.metadata.get("path", None):
dataset = hf_datasets.load_dataset(**dataset_def.metadata)
else:
df = get_dataframe_from_url(dataset_def.url)
def parse_hf_params(dataset_def: Dataset):
uri = dataset_def.source.uri
parsed_uri = urlparse(uri)
params = parse_qs(parsed_uri.query)
params = {k: v[0] for k, v in params.items()}
path = parsed_uri.path.lstrip("/")
if df is None:
raise ValueError(f"Failed to load dataset from {dataset_def.url}")
dataset = hf_datasets.Dataset.from_pandas(df)
# drop columns not specified by schema
if dataset_def.dataset_schema:
dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
return dataset
return path, params
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@ -60,6 +55,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
self,
dataset_def: Dataset,
) -> None:
print("register_dataset")
# Store in kvstore
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set(
@ -80,7 +76,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
limit: Optional[int] = None,
) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def)
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
start_index = start_index or 0
@ -98,15 +95,20 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def)
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
# Convert rows to HF Dataset format
new_dataset = hf_datasets.Dataset.from_list(rows)
# Concatenate the new rows with existing dataset
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
updated_dataset = hf_datasets.concatenate_datasets(
[loaded_dataset, new_dataset]
)
if dataset_def.metadata.get("path", None):
updated_dataset.push_to_hub(dataset_def.metadata["path"])
else:
raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
raise NotImplementedError(
"Uploading to URL-based datasets is not supported yet"
)