From 731faa8dbd68f776c9260957a09fd8162063bc11 Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Fri, 18 Jul 2025 12:45:56 -0400
Subject: [PATCH] fix: retry on HF Failure

HF Datasets downloads can fail due to rate limits. If any of the
following are raised, retry: FileNotFoundError, RepositoryNotFoundError,
EntryNotFoundError, RevisionNotFoundError, GatedRepoError. These are the
exceptions the downloader catches and raises on failure.

Signed-off-by: Charlie Doern
---
 .../datasetio/huggingface/huggingface.py      | 29 ++++++++++++++++---
 1 file changed, 25 insertions(+), 4 deletions(-)

diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
index fafd1d8ff..00746002e 100644
--- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
+++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
@@ -3,14 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import time
 from typing import Any
 from urllib.parse import parse_qs, urlparse
 
 import datasets as hf_datasets
+from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError
 
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
@@ -18,6 +21,10 @@ from llama_stack.providers.utils.pagination import paginate_records
 from .config import HuggingfaceDatasetIOConfig
 
 DATASETS_PREFIX = "datasets:"
+MAX_RETRIES = 3
+RETRY_DELAY = 5
+
+logger = get_logger(name=__name__, category="datasets")
 
 
 def parse_hf_params(dataset_def: Dataset):
@@ -75,10 +82,24 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
     ) -> PaginatedResponse:
         dataset_def = self.dataset_infos[dataset_id]
         path, params = parse_hf_params(dataset_def)
-        loaded_dataset = hf_datasets.load_dataset(path, **params)
-
-        records = [loaded_dataset[i] for i in range(len(loaded_dataset))]
-        return paginate_records(records, start_index, limit)
+        for attempt in range(1, MAX_RETRIES + 1):
+            try:
+                loaded_dataset = hf_datasets.load_dataset(path, **params)
+                records = [loaded_dataset[i] for i in range(len(loaded_dataset))]
+                return paginate_records(records, start_index, limit)
+            except (
+                FileNotFoundError,
+                RepositoryNotFoundError,
+                EntryNotFoundError,
+                RevisionNotFoundError,
+                GatedRepoError,
+            ) as e:
+                if attempt == MAX_RETRIES:
+                    raise  # Re-raise the last exception
+                logger.error(
+                    f"Attempt {attempt} to download HF Dataset failed with error: {e}. Retrying in {RETRY_DELAY} seconds..."
+                )
+                time.sleep(RETRY_DELAY)
 
     async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
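
For reviewers who want to exercise the retry behavior outside llama-stack,
below is a minimal standalone sketch of the pattern this patch adds. The
helper name load_dataset_with_retry, the RETRYABLE_ERRORS tuple, and the use
of the stdlib logging module (in place of llama_stack.log.get_logger) are
illustrative assumptions, not part of the patch.

    import logging
    import time

    import datasets as hf_datasets
    from huggingface_hub.utils import (
        EntryNotFoundError,
        GatedRepoError,
        RepositoryNotFoundError,
        RevisionNotFoundError,
    )

    MAX_RETRIES = 3
    RETRY_DELAY = 5  # seconds between attempts

    logger = logging.getLogger(__name__)

    # Same exception set the patch retries on: the errors the HF downloader
    # catches and raises when a download fails.
    RETRYABLE_ERRORS = (
        FileNotFoundError,
        RepositoryNotFoundError,
        EntryNotFoundError,
        RevisionNotFoundError,
        GatedRepoError,
    )

    def load_dataset_with_retry(path, **params):
        # Hypothetical helper: try hf_datasets.load_dataset up to MAX_RETRIES
        # times, sleeping RETRY_DELAY seconds between failed attempts.
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                return hf_datasets.load_dataset(path, **params)
            except RETRYABLE_ERRORS as e:
                if attempt == MAX_RETRIES:
                    raise  # out of retries; surface the last error
                logger.error(
                    "Attempt %d to download HF Dataset failed with error: %s. Retrying in %d seconds...",
                    attempt,
                    e,
                    RETRY_DELAY,
                )
                time.sleep(RETRY_DELAY)

The loop either returns on the first successful attempt or re-raises the
final exception, so callers see the same error surface as before the patch
when all attempts fail.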