diff --git a/llama_stack/providers/adapters/datasetio/huggingface/config.py b/llama_stack/providers/adapters/datasetio/huggingface/config.py
index 89dbe53a0..46470ce49 100644
--- a/llama_stack/providers/adapters/datasetio/huggingface/config.py
+++ b/llama_stack/providers/adapters/datasetio/huggingface/config.py
@@ -3,7 +3,15 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.apis.datasetio import *  # noqa: F401, F403
+from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
+from pydantic import BaseModel
 
 
-class HuggingfaceDatasetIOConfig(BaseModel): ...
+class HuggingfaceDatasetIOConfig(BaseModel):
+    kvstore: KVStoreConfig = SqliteKVStoreConfig(
+        db_path=(RUNTIME_BASE_DIR / "huggingface_datasetio.db").as_posix()
+    )  # Uses SQLite config specific to HF storage
diff --git a/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py b/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py
index cd143a3ef..8d34df672 100644
--- a/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py
+++ b/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py
@@ -11,9 +11,12 @@ from llama_stack.apis.datasetio import *  # noqa: F403
 import datasets as hf_datasets
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .config import HuggingfaceDatasetIOConfig
 
+DATASETS_PREFIX = "datasets:"
+
 
 def load_hf_dataset(dataset_def: Dataset):
     if dataset_def.metadata.get("path", None):
@@ -33,9 +36,18 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self.config = config
         # local registry for keeping track of datasets within the provider
         self.dataset_infos = {}
+        self.kvstore = None
 
     async def initialize(self) -> None:
-        pass
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+        # Load existing datasets from kvstore
+        start_key = DATASETS_PREFIX
+        end_key = f"{DATASETS_PREFIX}\xff"
+        stored_datasets = await self.kvstore.range(start_key, end_key)
+
+        for dataset in stored_datasets:
+            dataset = Dataset.model_validate_json(dataset)
+            self.dataset_infos[dataset.identifier] = dataset
 
     async def shutdown(self) -> None: ...
 
@@ -43,6 +55,12 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self,
         dataset_def: Dataset,
     ) -> None:
+        # Store in kvstore
+        key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
+        await self.kvstore.set(
+            key=key,
+            value=dataset_def.json(),
+        )
         self.dataset_infos[dataset_def.identifier] = dataset_def
 
     async def get_rows_paginated(
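
For readers skimming the diff: the change gives the Hugging Face datasetio provider a persistent registry. Each registered dataset is written to the kvstore under a `datasets:` prefix, and `initialize()` does a range scan over that prefix to rebuild `self.dataset_infos` after a restart. Below is a minimal, self-contained sketch of that prefix-key / range-scan pattern; `InMemoryKVStore`, `DatasetRecord`, `register`, and `restore` are hypothetical stand-ins for illustration (pydantic v2 assumed), not llama-stack classes or APIs.

```python
# Sketch of the persistence pattern in this PR: writes are keyed as
# "datasets:<identifier>", and startup scans [prefix, prefix + "\xff")
# to rebuild the in-memory registry. All names here are stand-ins.
import asyncio

from pydantic import BaseModel


class DatasetRecord(BaseModel):
    # Hypothetical, much-reduced stand-in for the real Dataset definition.
    identifier: str
    provider_id: str = "huggingface"


class InMemoryKVStore:
    """Stand-in kvstore exposing the same set/range surface the diff uses."""

    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def range(self, start_key: str, end_key: str) -> list[str]:
        # Return values whose keys fall in the half-open range [start_key, end_key).
        return [v for k, v in sorted(self._data.items()) if start_key <= k < end_key]


DATASETS_PREFIX = "datasets:"


async def register(kvstore: InMemoryKVStore, record: DatasetRecord) -> None:
    # Persist the registration so it survives a provider restart.
    await kvstore.set(
        key=f"{DATASETS_PREFIX}{record.identifier}",
        value=record.model_dump_json(),
    )


async def restore(kvstore: InMemoryKVStore) -> dict[str, DatasetRecord]:
    # On initialize(): scan every key under the prefix and rebuild the registry.
    stored = await kvstore.range(DATASETS_PREFIX, f"{DATASETS_PREFIX}\xff")
    records = [DatasetRecord.model_validate_json(s) for s in stored]
    return {r.identifier: r for r in records}


async def demo() -> None:
    kv = InMemoryKVStore()
    await register(kv, DatasetRecord(identifier="my-eval-set"))
    print(await restore(kv))  # {'my-eval-set': DatasetRecord(...)}


if __name__ == "__main__":
    asyncio.run(demo())
```

The `"\xff"` suffix serves as an exclusive upper bound for the scan: any key that starts with `datasets:` followed by ordinary characters sorts lexicographically below `datasets:\xff`, so the range picks up exactly the provider's own entries.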