From 4fee3af91ffed2ec695f9f9546a12af1ce3b993c Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sat, 15 Mar 2025 16:37:46 -0700
Subject: [PATCH] unregister

---
 .../inline/datasetio/localfs/datasetio.py    | 110 ++++--------------
 .../datasetio/huggingface/huggingface.py     |   3 +-
 tests/integration/datasets/test_datasets.py  |  42 ++++++-
 tests/integration/datasets/test_script.py    |  20 +++-
 4 files changed, 79 insertions(+), 96 deletions(-)

diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
index afa9ee0ff..26fb97072 100644
--- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py
+++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
@@ -3,16 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
-import os
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
-from urllib.parse import urlparse
 
 import pandas
 
-from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
@@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
 DATASETS_PREFIX = "localfs_datasets:"
 
 
-class BaseDataset(ABC):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-    @abstractmethod
-    def __len__(self) -> int:
-        raise NotImplementedError()
-
-    @abstractmethod
-    def __getitem__(self, idx):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def load(self):
-        raise NotImplementedError()
-
-
-@dataclass
-class DatasetInfo:
-    dataset_def: Dataset
-    dataset_impl: BaseDataset
-
-
-class PandasDataframeDataset(BaseDataset):
+class PandasDataframeDataset:
     def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.dataset_def = dataset_def
@@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
         else:
             return self.df.iloc[idx].to_dict()
 
-    def _validate_dataset_schema(self, df) -> pandas.DataFrame:
-        # note that we will drop any columns in dataset that are not in the schema
-        df = df[self.dataset_def.dataset_schema.keys()]
-        # check all columns in dataset schema are present
-        assert len(df.columns) == len(self.dataset_def.dataset_schema)
-        # TODO: type checking against column types in dataset schema
-        return df
-
     def load(self) -> None:
         if self.df is not None:
             return
 
-        df = get_dataframe_from_url(self.dataset_def.url)
-        if df is None:
-            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
+        if self.dataset_def.source.type == "uri":
+            self.df = get_dataframe_from_url(self.dataset_def.source.uri)
+        elif self.dataset_def.source.type == "rows":
+            self.df = pandas.DataFrame(self.dataset_def.source.rows)
+        else:
+            raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
 
-        self.df = self._validate_dataset_schema(df)
+        if self.df is None:
+            raise ValueError(f"Failed to load dataset from {self.dataset_def.source}")
 
 
 class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -99,29 +66,21 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
 
         for dataset in stored_datasets:
             dataset = Dataset.model_validate_json(dataset)
-            dataset_impl = PandasDataframeDataset(dataset)
-            self.dataset_infos[dataset.identifier] = DatasetInfo(
-                dataset_def=dataset,
-                dataset_impl=dataset_impl,
-            )
+            self.dataset_infos[dataset.identifier] = dataset
 
     async def shutdown(self) -> None: ...
 
     async def register_dataset(
         self,
-        dataset: Dataset,
+        dataset_def: Dataset,
     ) -> None:
         # Store in kvstore
-        key = f"{DATASETS_PREFIX}{dataset.identifier}"
+        key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=dataset.json(),
-        )
-        dataset_impl = PandasDataframeDataset(dataset)
-        self.dataset_infos[dataset.identifier] = DatasetInfo(
-            dataset_def=dataset,
-            dataset_impl=dataset_impl,
+            value=dataset_def.model_dump_json(),
         )
+        self.dataset_infos[dataset_def.identifier] = dataset_def
 
     async def unregister_dataset(self, dataset_id: str) -> None:
         key = f"{DATASETS_PREFIX}{dataset_id}"
@@ -134,51 +93,28 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         start_index: Optional[int] = None,
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
-        dataset_info = self.dataset_infos.get(dataset_id)
-        dataset_info.dataset_impl.load()
+        dataset_def = self.dataset_infos[dataset_id]
+        dataset_impl = PandasDataframeDataset(dataset_def)
+        dataset_impl.load()
 
         start_index = start_index or 0
 
         if limit is None or limit == -1:
-            end = len(dataset_info.dataset_impl)
+            end = len(dataset_impl)
         else:
-            end = min(start_index + limit, len(dataset_info.dataset_impl))
+            end = min(start_index + limit, len(dataset_impl))
 
-        rows = dataset_info.dataset_impl[start_index:end]
+        rows = dataset_impl[start_index:end]
 
         return IterrowsResponse(
             data=rows,
-            next_index=end if end < len(dataset_info.dataset_impl) else None,
+            next_index=end if end < len(dataset_impl) else None,
         )
 
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
-        dataset_info = self.dataset_infos.get(dataset_id)
-        if dataset_info is None:
-            raise ValueError(f"Dataset with id {dataset_id} not found")
-
-        dataset_impl = dataset_info.dataset_impl
+        dataset_def = self.dataset_infos[dataset_id]
+        dataset_impl = PandasDataframeDataset(dataset_def)
         dataset_impl.load()
 
         new_rows_df = pandas.DataFrame(rows)
-        new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
         dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
-
-        url = str(dataset_info.dataset_def.url.uri)
-        parsed_url = urlparse(url)
-
-        if parsed_url.scheme == "file" or not parsed_url.scheme:
-            file_path = parsed_url.path
-            os.makedirs(os.path.dirname(file_path), exist_ok=True)
-            dataset_impl.df.to_csv(file_path, index=False)
-        elif parsed_url.scheme == "data":
-            # For data URLs, we need to update the base64-encoded content
-            if not parsed_url.path.startswith("text/csv;base64,"):
-                raise ValueError("Data URL must be a base64-encoded CSV")
-
-            csv_buffer = dataset_impl.df.to_csv(index=False)
-            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
-            dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
-        else:
-            raise ValueError(
-                f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
- ) diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index 23a6cffb5..41ce747f7 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -52,12 +52,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): self, dataset_def: Dataset, ) -> None: - print("register_dataset") # Store in kvstore key = f"{DATASETS_PREFIX}{dataset_def.identifier}" await self.kvstore.set( key=key, - value=dataset_def.json(), + value=dataset_def.model_dump_json(), ) self.dataset_infos[dataset_def.identifier] = dataset_def diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py index fdae5420c..dea0ac2aa 100644 --- a/tests/integration/datasets/test_datasets.py +++ b/tests/integration/datasets/test_datasets.py @@ -13,7 +13,7 @@ import pytest @pytest.mark.parametrize( - "purpose, source, provider_id", + "purpose, source, provider_id, limit", [ ( "eval/messages-answer", @@ -22,16 +22,48 @@ import pytest "uri": "huggingface://datasets/llamastack/simpleqa?split=train", }, "huggingface", + 10, + ), + ( + "eval/messages-answer", + { + "type": "rows", + "rows": [ + { + "messages": [{"role": "user", "content": "Hello, world!"}], + "answer": "Hello, world!", + }, + { + "messages": [ + { + "role": "user", + "content": "What is the capital of France?", + } + ], + "answer": "Paris", + }, + ], + }, + "localfs", + 2, ), ], ) -def test_register_dataset(llama_stack_client, purpose, source, provider_id): +def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit): dataset = llama_stack_client.datasets.register( purpose=purpose, source=source, ) assert dataset.identifier is not None assert dataset.provider_id == provider_id - iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=10) - assert len(iterrow_response.data) == 10 - assert iterrow_response.next_index is not None + iterrow_response = llama_stack_client.datasets.iterrows( + dataset.identifier, limit=limit + ) + assert len(iterrow_response.data) == limit + + dataset_list = llama_stack_client.datasets.list() + assert dataset.identifier in [d.identifier for d in dataset_list] + + llama_stack_client.datasets.unregister(dataset.identifier) + dataset_list = llama_stack_client.datasets.list() + assert dataset.identifier not in [d.identifier for d in dataset_list] diff --git a/tests/integration/datasets/test_script.py b/tests/integration/datasets/test_script.py index 99bd27460..a1587a09e 100644 --- a/tests/integration/datasets/test_script.py +++ b/tests/integration/datasets/test_script.py @@ -10,11 +10,27 @@ from rich.pretty import pprint def test_register_dataset(): client = LlamaStackClient(base_url="http://localhost:8321") + # dataset = client.datasets.register( + # purpose="eval/messages-answer", + # source={ + # "type": "uri", + # "uri": "huggingface://datasets/llamastack/simpleqa?split=train", + # }, + # ) dataset = client.datasets.register( purpose="eval/messages-answer", source={ - "type": "uri", - "uri": "huggingface://datasets/llamastack/simpleqa?split=train", + "type": "rows", + "rows": [ + { + "messages": [{"role": "user", "content": "Hello, world!"}], + "answer": "Hello, world!", + }, + { + "messages": [{"role": "user", "content": "What is the capital of France?"}], + "answer": "Paris", + }, + ], }, ) dataset_id = dataset.identifier