From 4253cfcd7f59fedd4747f706db0fff5971e7c48d Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 14 Nov 2024 00:08:37 -0500 Subject: [PATCH] local persistence for hf dataset provider (#451) # What does this PR do? - local persistence for HF dataset provider - follow https://github.com/meta-llama/llama-stack/pull/375 ## Test Plan **e2e** 1. fresh llama stack run w/ yaml 2. kill server 3. restart llama stack run w/ yaml ```yaml datasets: - dataset_id: mmlu provider_id: huggingface-0 url: uri: https://huggingface.co/datasets/llamastack/evals metadata: path: llamastack/evals name: evals__mmlu__details split: train dataset_schema: input_query: type: string expected_answer: type: string ``` image ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- .../adapters/datasetio/huggingface/config.py | 12 +++++++++-- .../datasetio/huggingface/huggingface.py | 20 ++++++++++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/adapters/datasetio/huggingface/config.py b/llama_stack/providers/adapters/datasetio/huggingface/config.py index 89dbe53a0..46470ce49 100644 --- a/llama_stack/providers/adapters/datasetio/huggingface/config.py +++ b/llama_stack/providers/adapters/datasetio/huggingface/config.py @@ -3,7 +3,15 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_stack.apis.datasetio import * # noqa: F401, F403 +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) +from pydantic import BaseModel -class HuggingfaceDatasetIOConfig(BaseModel): ... +class HuggingfaceDatasetIOConfig(BaseModel): + kvstore: KVStoreConfig = SqliteKVStoreConfig( + db_path=(RUNTIME_BASE_DIR / "huggingface_datasetio.db").as_posix() + ) # Uses SQLite config specific to HF storage diff --git a/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py b/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py index cd143a3ef..8d34df672 100644 --- a/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py @@ -11,9 +11,12 @@ from llama_stack.apis.datasetio import * # noqa: F403 import datasets as hf_datasets from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url +from llama_stack.providers.utils.kvstore import kvstore_impl from .config import HuggingfaceDatasetIOConfig +DATASETS_PREFIX = "datasets:" + def load_hf_dataset(dataset_def: Dataset): if dataset_def.metadata.get("path", None): @@ -33,9 +36,18 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): self.config = config # local registry for keeping track of datasets within the provider self.dataset_infos = {} + self.kvstore = None async def initialize(self) -> None: - pass + self.kvstore = await kvstore_impl(self.config.kvstore) + # Load existing datasets from kvstore + start_key = DATASETS_PREFIX + end_key = f"{DATASETS_PREFIX}\xff" + stored_datasets = await self.kvstore.range(start_key, end_key) + + for dataset in stored_datasets: + dataset = Dataset.model_validate_json(dataset) + self.dataset_infos[dataset.identifier] = dataset async def shutdown(self) 
-> None: ... @@ -43,6 +55,12 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): self, dataset_def: Dataset, ) -> None: + # Store in kvstore + key = f"{DATASETS_PREFIX}{dataset_def.identifier}" + await self.kvstore.set( + key=key, + value=dataset_def.json(), + ) self.dataset_infos[dataset_def.identifier] = dataset_def async def get_rows_paginated(