From f8d9e4f60f13d305c8325c2cbed20910c1ed6281 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 22 Oct 2024 13:09:17 -0700 Subject: [PATCH] dataset datasetio --- llama_stack/apis/datasetio/datasetio.py | 4 +- .../distribution/routers/routing_tables.py | 15 +++ .../meta_reference/datasetio/datasetio.py | 97 ++++++++++++++-- llama_stack/providers/registry/datasetio.py | 2 +- .../providers/tests/datasetio/__init__.py | 5 + .../datasetio/provider_config_example.yaml | 4 + .../tests/datasetio/test_datasetio.py | 109 ++++++++++++++++++ .../utils/datasetio/dataset_utils.py | 23 ++++ 8 files changed, 249 insertions(+), 10 deletions(-) create mode 100644 llama_stack/providers/tests/datasetio/__init__.py create mode 100644 llama_stack/providers/tests/datasetio/provider_config_example.yaml create mode 100644 llama_stack/providers/tests/datasetio/test_datasetio.py create mode 100644 llama_stack/providers/utils/datasetio/dataset_utils.py diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index e8811d233..366af8602 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -17,7 +17,7 @@ class PaginatedRowsResult(BaseModel): # the rows obey the DatasetSchema for the given dataset rows: List[Dict[str, Any]] total_count: int - next_page_token: Optional[str] = None + next_page_token: Optional[int] = None class DatasetStore(Protocol): @@ -34,6 +34,6 @@ class DatasetIO(Protocol): self, dataset_id: str, rows_in_page: int, - page_token: Optional[str] = None, + page_token: Optional[int] = None, filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 1f44ca08b..6656c34af 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -28,6 +28,10 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: await p.register_shield(obj) elif api == Api.memory: await p.register_memory_bank(obj) + elif api == Api.datasetio: + await p.register_dataset(obj) + else: + raise ValueError(f"Unknown API {api} for registering object with provider") Registry = Dict[str, List[RoutableObjectWithProvider]] @@ -81,6 +85,16 @@ class CommonRoutingTableImpl(RoutingTable): add_objects(memory_banks) + elif api == Api.datasetio: + p.dataset_store = self + datasets = await p.list_datasets() + + # do in-memory updates due to pesky Annotated unions + for d in datasets: + d.provider_id = pid + + add_objects(datasets) + async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): await p.shutdown() @@ -138,6 +152,7 @@ class CommonRoutingTableImpl(RoutingTable): raise ValueError(f"Provider `{obj.provider_id}` not found") p = self.impls_by_provider_id[obj.provider_id] + await register_object_with_provider(obj, p) if obj.identifier not in self.registry: diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py index d914c0515..57bfefc31 100644 --- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py @@ -5,17 +5,83 @@ # the root directory of this source tree. from typing import List, Optional +import pandas + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 +from dataclasses import dataclass + from llama_stack.providers.datatypes import DatasetsProtocolPrivate +from llama_stack.providers.utils.datasetio.dataset_utils import BaseDataset from .config import MetaReferenceDatasetIOConfig +@dataclass +class DatasetInfo: + dataset_def: DatasetDef + dataset_impl: BaseDataset + next_page_token: Optional[int] = None + + +class CustomDataset(BaseDataset): + def __init__(self, dataset_def: DatasetDef, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.dataset_def = dataset_def + # TODO: validate dataset_def against schema + self.df = None + self.load() + + def __len__(self) -> int: + if self.df is None: + self.load() + return len(self.df) + + def __getitem__(self, idx): + if isinstance(idx, slice): + return self.df.iloc[idx].to_dict(orient="records") + else: + return self.df.iloc[idx].to_dict() + + def load(self) -> None: + if self.df: + return + + # TODO: more robust support w/ data url + if self.dataset_def.url.uri.endswith(".csv"): + df = pandas.read_csv(self.dataset_def.url.uri) + elif self.dataset_def.url.uri.endswith(".xlsx"): + df = pandas.read_excel(self.dataset_def.url.uri) + elif self.dataset_def.url.uri.startswith("data:"): + parts = parse_data_url(self.dataset_def.url.uri) + data = parts["data"] + if parts["is_base64"]: + data = base64.b64decode(data) + else: + data = unquote(data) + encoding = parts["encoding"] or "utf-8" + data = data.encode(encoding) + + mime_type = parts["mimetype"] + mime_category = mime_type.split("/")[0] + data_bytes = io.BytesIO(data) + + if mime_category == "text": + df = pandas.read_csv(data_bytes) + else: + df = pandas.read_excel(data_bytes) + else: + raise ValueError(f"Unsupported file type: {self.dataset_def.url}") + + self.df = df + + class MetaReferenceDatasetioImpl(DatasetIO, DatasetsProtocolPrivate): def __init__(self, config: MetaReferenceDatasetIOConfig) -> None: self.config = config + # local registry for keeping track of datasets within the provider + self.dataset_infos = {} async def initialize(self) -> None: ... @@ -23,21 +89,38 @@ class MetaReferenceDatasetioImpl(DatasetIO, DatasetsProtocolPrivate): async def register_dataset( self, - memory_bank: DatasetDef, + dataset_def: DatasetDef, ) -> None: - print("register dataset") + self.dataset_infos[dataset_def.identifier] = DatasetInfo( + dataset_def=dataset_def, + dataset_impl=CustomDataset(dataset_def), + next_page_token=0, + ) async def list_datasets(self) -> List[DatasetDef]: - print("list datasets") - return [] + return [i.dataset_def for i in self.dataset_infos.values()] async def get_rows_paginated( self, dataset_id: str, rows_in_page: int, - page_token: Optional[str] = None, + page_token: Optional[int] = None, filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: - print("get rows paginated") + dataset_info = self.dataset_infos.get(dataset_id) + if page_token is None: + dataset_info.next_page_token = 0 - return PaginatedRowsResult(rows=[], total_count=1, next_page_token=None) + if rows_in_page == -1: + rows = dataset_info.dataset_impl[dataset_info.next_page_token :] + + start = dataset_info.next_page_token + end = min(start + rows_in_page, len(dataset_info.dataset_impl)) + rows = dataset_info.dataset_impl[start:end] + dataset_info.next_page_token = end + + return PaginatedRowsResult( + rows=rows, + total_count=len(rows), + next_page_token=dataset_info.next_page_token, + ) diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index 63da2a2ee..89a0a38f6 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -14,7 +14,7 @@ def available_providers() -> List[ProviderSpec]: InlineProviderSpec( api=Api.datasetio, provider_type="meta-reference", - pip_packages=[], + pip_packages=["pandas"], module="llama_stack.providers.impls.meta_reference.datasetio", config_class="llama_stack.providers.impls.meta_reference.datasetio.MetaReferenceDatasetIOConfig", api_dependencies=[], diff --git a/llama_stack/providers/tests/datasetio/__init__.py b/llama_stack/providers/tests/datasetio/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/tests/datasetio/provider_config_example.yaml b/llama_stack/providers/tests/datasetio/provider_config_example.yaml new file mode 100644 index 000000000..c0565a39e --- /dev/null +++ b/llama_stack/providers/tests/datasetio/provider_config_example.yaml @@ -0,0 +1,4 @@ +providers: + - provider_id: test-meta + provider_type: meta-reference + config: {} diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py new file mode 100644 index 000000000..0d8b6e49b --- /dev/null +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import os + +import pytest +import pytest_asyncio + +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.tests.resolver import resolve_impls_for_test + +# How to run this test: +# +# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky +# since it depends on the provider you are testing. On top of that you need +# `pytest` and `pytest-asyncio` installed. +# +# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. +# +# 3. Run: +# +# ```bash +# PROVIDER_ID= \ +# PROVIDER_CONFIG=provider_config.yaml \ +# pytest -s llama_stack/providers/tests/datasetio/test_datasetio.py \ +# --tb=short --disable-warnings +# ``` + + +@pytest_asyncio.fixture(scope="session") +async def datasetio_settings(): + impls = await resolve_impls_for_test( + Api.datasetio, + ) + return { + "datasetio_impl": impls[Api.datasetio], + "datasets_impl": impls[Api.datasets], + } + + +async def register_dataset(datasets_impl: Datasets): + dataset = DatasetDefWithProvider( + identifier="test_dataset", + provider_id=os.environ["PROVIDER_ID"], + url=URL( + uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", + ), + columns_schema={}, + ) + await datasets_impl.register_dataset(dataset) + + +@pytest.mark.asyncio +async def test_datasets_list(datasetio_settings): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + datasets_impl = datasetio_settings["datasets_impl"] + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 0 + + +@pytest.mark.asyncio +async def test_datasets_register(datasetio_settings): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + datasets_impl = datasetio_settings["datasets_impl"] + await register_dataset(datasets_impl) + + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 1 + + # register same dataset with same id again will fail + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 1 + assert response[0].identifier == "test_dataset" + + +@pytest.mark.asyncio +async def test_get_rows_paginated(datasetio_settings): + datasetio_impl = datasetio_settings["datasetio_impl"] + datasets_impl = datasetio_settings["datasets_impl"] + await register_dataset(datasets_impl) + + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=3, + ) + + assert isinstance(response.rows, list) + assert len(response.rows) == 3 + assert response.next_page_token == 3 + + # iterate over all rows + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=10, + page_token=response.next_page_token, + ) + + assert isinstance(response.rows, list) + assert len(response.rows) == 10 + assert response.next_page_token == 13 diff --git a/llama_stack/providers/utils/datasetio/dataset_utils.py b/llama_stack/providers/utils/datasetio/dataset_utils.py new file mode 100644 index 000000000..960004937 --- /dev/null +++ b/llama_stack/providers/utils/datasetio/dataset_utils.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from abc import ABC, abstractmethod + + +class BaseDataset(ABC): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + @abstractmethod + def __len__(self) -> int: + raise NotImplementedError() + + @abstractmethod + def __getitem__(self, idx): + raise NotImplementedError() + + @abstractmethod + def load(self): + raise NotImplementedError()