This commit is contained in:
Charlie Doern 2025-08-04 18:47:58 +00:00 committed by GitHub
commit cfa703f107
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3,14 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from typing import Any
from urllib.parse import parse_qs, urlparse
import datasets as hf_datasets
from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
@ -18,6 +21,10 @@ from llama_stack.providers.utils.pagination import paginate_records
from .config import HuggingfaceDatasetIOConfig
# NOTE(review): not referenced in the visible part of this file; presumably the
# key prefix under which dataset definitions are stored in the kvstore — confirm.
DATASETS_PREFIX = "datasets:"
# Total number of attempts made to load a Hugging Face dataset. The loop that
# uses this counts attempts from 1 and re-raises on attempt MAX_RETRIES, so this
# is the attempt count, not the number of retries after the first failure.
MAX_RETRIES = 3
# Pause between failed attempts, in seconds (passed to time.sleep).
RETRY_DELAY = 5
# Module-level logger, created under the "datasets" category.
logger = get_logger(name=__name__, category="datasets")
def parse_hf_params(dataset_def: Dataset):
@ -75,10 +82,24 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
) -> PaginatedResponse:
dataset_def = self.dataset_infos[dataset_id]
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
records = [loaded_dataset[i] for i in range(len(loaded_dataset))]
return paginate_records(records, start_index, limit)
for attempt in range(1, MAX_RETRIES + 1):
try:
loaded_dataset = hf_datasets.load_dataset(path, **params)
records = [loaded_dataset[i] for i in range(len(loaded_dataset))]
return paginate_records(records, start_index, limit)
except (
FileNotFoundError,
RepositoryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError,
GatedRepoError,
) as e:
if attempt == MAX_RETRIES:
raise # Re-raise the last exception
logger.error(
f"Attempt {attempt} to download HF Dataset failed with error: {e}. Retrying in {RETRY_DELAY} seconds..."
)
time.sleep(RETRY_DELAY)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]