llama-stack/llama_stack/providers/adapters/datasetio/huggingface/huggingface.py
Xi Yan 4253cfcd7f
Local persistence for HF dataset provider (#451)
# What does this PR do?

- Add local persistence for the HF dataset provider
- Follows the approach of https://github.com/meta-llama/llama-stack/pull/375

## Test Plan

**e2e**
1. Start a fresh `llama stack run` with the YAML below
2. Kill the server
3. Restart `llama stack run` with the same YAML

```yaml
datasets:
  - dataset_id: mmlu
    provider_id: huggingface-0
    url:
      uri: https://huggingface.co/datasets/llamastack/evals
    metadata:
      path: llamastack/evals
      name: evals__mmlu__details
      split: train
    dataset_schema:
      input_query:
        type: string
      expected_answer:
        type: string
```
<img width="686" alt="image" src="https://github.com/user-attachments/assets/d7737931-6a7d-400a-a17d-fef6cbd97eea">


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the Pull Request section of the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md).
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2024-11-14 00:08:37 -05:00

96 lines
3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_stack.apis.datasetio import * # noqa: F403
import datasets as hf_datasets
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import HuggingfaceDatasetIOConfig
# Key prefix for dataset definitions persisted in the kvstore. Keys are built as
# f"{DATASETS_PREFIX}{identifier}", and initialize() range-scans from this prefix
# up to f"{DATASETS_PREFIX}\xff" to reload every stored definition.
DATASETS_PREFIX = "datasets:"
def load_hf_dataset(dataset_def: Dataset):
    """Materialize ``dataset_def`` as a Hugging Face dataset object.

    Two loading strategies are supported:

    * If ``dataset_def.metadata`` has a truthy ``"path"`` entry, the entire
      metadata mapping is forwarded as keyword arguments to
      ``datasets.load_dataset``.
    * Otherwise the data is fetched from ``dataset_def.url`` with
      ``get_dataframe_from_url`` and converted via
      ``datasets.Dataset.from_pandas``.

    NOTE(review): without a ``split`` key in the metadata, ``load_dataset``
    may return a DatasetDict rather than a single Dataset. Confirm that
    callers always register a split.

    Raises:
        ValueError: if the URL-based loader returns ``None``.
    """
    if not dataset_def.metadata.get("path", None):
        frame = get_dataframe_from_url(dataset_def.url)
        if frame is None:
            raise ValueError(f"Failed to load dataset from {dataset_def.url}")
        return hf_datasets.Dataset.from_pandas(frame)

    return hf_datasets.load_dataset(**dataset_def.metadata)
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
    """DatasetIO provider that serves rows from Hugging Face ``datasets``.

    Registered dataset definitions are kept in an in-memory registry and are
    also written to the configured kvstore under ``DATASETS_PREFIX``.
    ``initialize()`` reloads them, so registrations survive a restart.
    """

    def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
        self.config = config
        # local registry for keeping track of datasets within the provider,
        # keyed by dataset identifier
        self.dataset_infos = {}
        # opened in initialize(), because kvstore_impl is a coroutine
        self.kvstore = None

    async def initialize(self) -> None:
        """Open the kvstore and reload every previously registered dataset."""
        self.kvstore = await kvstore_impl(self.config.kvstore)
        # Load existing datasets from kvstore. The "\xff" suffix serves as an
        # upper bound for keys that start with DATASETS_PREFIX.
        # NOTE(review): identifiers containing characters above U+00FF would
        # sort past this bound. Confirm that identifiers are ASCII/Latin-1.
        start_key = DATASETS_PREFIX
        end_key = f"{DATASETS_PREFIX}\xff"
        stored_datasets = await self.kvstore.range(start_key, end_key)
        for raw_json in stored_datasets:
            dataset = Dataset.model_validate_json(raw_json)
            self.dataset_infos[dataset.identifier] = dataset

    async def shutdown(self) -> None: ...

    async def register_dataset(
        self,
        dataset_def: Dataset,
    ) -> None:
        """Persist ``dataset_def`` to the kvstore and add it to the registry.

        Registering the same identifier again replaces the in-memory entry
        and writes to the same kvstore key.
        """
        # Store in kvstore. Use pydantic v2's model_dump_json() to mirror the
        # model_validate_json() call that reloads it in initialize(). The
        # v1-style .json() is deprecated under pydantic v2.
        key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
        await self.kvstore.set(
            key=key,
            value=dataset_def.model_dump_json(),
        )
        self.dataset_infos[dataset_def.identifier] = dataset_def

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        """Return one page of rows from a registered dataset.

        Args:
            dataset_id: identifier of a previously registered dataset.
            rows_in_page: maximum number of rows to return. ``-1`` returns
                every row from the start offset to the end of the dataset.
            page_token: decimal string giving the starting row offset, such
                as the ``next_page_token`` from a previous call. ``None`` or
                ``""`` starts at row 0.
            filter_condition: accepted for interface compatibility. This
                provider does not apply it.

        Returns:
            A PaginatedRowsResult with these fields:

            * ``rows``: the selected rows.
            * ``total_count``: the number of rows in this page.
            * ``next_page_token``: the offset just past the last returned row.

        Raises:
            KeyError: if ``dataset_id`` has not been registered.
            ValueError: if ``page_token`` is not a string of decimal digits.
        """
        dataset_def = self.dataset_infos[dataset_id]

        # Validate the token before the potentially expensive dataset load.
        # Use isdecimal() rather than isnumeric(): isnumeric() also accepts
        # characters such as "½" or "²", which int() cannot parse.
        if page_token and not page_token.isdecimal():
            raise ValueError("Invalid page_token")
        start = int(page_token) if page_token else 0

        # NOTE(review): the dataset is reloaded on every call. Caching it per
        # dataset_id may be worthwhile for large datasets.
        loaded_dataset = load_hf_dataset(dataset_def)
        num_rows = len(loaded_dataset)
        if rows_in_page == -1:
            end = num_rows
        else:
            end = min(start + rows_in_page, num_rows)

        rows = [loaded_dataset[i] for i in range(start, end)]
        # NOTE(review): total_count reports the page size, not the dataset
        # size. Confirm this against the PaginatedRowsResult contract.
        return PaginatedRowsResult(
            rows=rows,
            total_count=len(rows),
            next_page_token=str(end),
        )