From dc2995842dbed2a2e6fe26b5523dbd15f956a85f Mon Sep 17 00:00:00 2001 From: Josh Salomon Date: Wed, 5 Mar 2025 20:27:25 +0200 Subject: [PATCH] feat: Add warning message for unsupported param The 2 datasetio providers inline::localfs and remote::huggingface do not support the filter_condition parameter that is defined for the get_rows_paginated API. This commit adds a warning message when non empty filter_condition is passed to this API for these providers. Signed-off-by: Josh Salomon --- .../inline/datasetio/localfs/datasetio.py | 13 ++++++++++++- .../remote/datasetio/huggingface/huggingface.py | 11 +++++++++++ .../providers/utils/common/provider_utils.py | 15 +++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 llama_stack/providers/utils/common/provider_utils.py diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 491f03f72..74faef2ef 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import base64 +import logging import os from abc import ABC, abstractmethod from dataclasses import dataclass @@ -16,11 +17,14 @@ from llama_stack.apis.common.content_types import URL from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasets import Dataset from llama_stack.providers.datatypes import DatasetsProtocolPrivate +from llama_stack.providers.utils.common.provider_utils import get_provider_type from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.kvstore import kvstore_impl from .config import LocalFSDatasetIOConfig +log = logging.getLogger(__name__) + DATASETS_PREFIX = "localfs_datasets:" @@ -141,6 +145,13 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): if page_token and not page_token.isnumeric(): raise ValueError("Invalid page_token") + if filter_condition is not None and filter_condition.strip(): + dataset_type = get_provider_type(self.__module__) + provider_id = dataset_info.dataset_def.provider_id + log.warning( + f"Data filtering is not supported yet for {dataset_type}::{provider_id}, ignoring filter_condition: {filter_condition}" + ) + if page_token is None or len(page_token) == 0: next_page_token = 0 else: @@ -172,7 +183,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df) dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) - url = str(dataset_info.dataset_def.url) + url = str(dataset_info.dataset_def.url.uri) parsed_url = urlparse(url) if parsed_url.scheme == "file" or not parsed_url.scheme: diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index cd4e7f1f1..9379ffdbb 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -3,6 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging from typing import Any, Dict, List, Optional import datasets as hf_datasets @@ -10,11 +11,14 @@ import datasets as hf_datasets from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasets import Dataset from llama_stack.providers.datatypes import DatasetsProtocolPrivate +from llama_stack.providers.utils.common.provider_utils import get_provider_type from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.kvstore import kvstore_impl from .config import HuggingfaceDatasetIOConfig +log = logging.getLogger(__name__) + DATASETS_PREFIX = "datasets:" @@ -86,6 +90,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): if page_token and not page_token.isnumeric(): raise ValueError("Invalid page_token") + if filter_condition is not None and filter_condition.strip(): + dataset_type = get_provider_type(self.__module__) + provider_id = dataset_def.provider_id + log.warning( + f"Data filtering is not supported yet for {dataset_type}::{provider_id}, ignoring filter_condition: {filter_condition}" + ) + if page_token is None or len(page_token) == 0: next_page_token = 0 else: diff --git a/llama_stack/providers/utils/common/provider_utils.py b/llama_stack/providers/utils/common/provider_utils.py new file mode 100644 index 000000000..8a4131f51 --- /dev/null +++ b/llama_stack/providers/utils/common/provider_utils.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +def get_provider_type(module: str) -> str: + parts = module.split(".") + if parts[0] != "llama_stack" or parts[1] != "providers": + raise ValueError(f"Invalid module name <{module}>") + if parts[2] == "inline" or parts[2] == "remote": + return parts[2] + else: + raise ValueError(f"Invalid module name <{module}>")