feat: Add warning message for unsupported param

The 2 datasetio providers inline::localfs and remote::huggingface do not
support the filter_condition parameter that is defined for the
get_rows_paginated API. This commit adds a warning message when a
non-empty filter_condition is passed to this API for these providers.

Signed-off-by: Josh Salomon <jsalomon@redhat.com>
This commit is contained in:
Josh Salomon 2025-03-05 20:27:25 +02:00
parent 99b6925ad8
commit dc2995842d
3 changed files with 38 additions and 1 deletions

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64 import base64
import logging
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
@@ -16,11 +17,14 @@ from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.common.provider_utils import get_provider_type
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import LocalFSDatasetIOConfig from .config import LocalFSDatasetIOConfig
log = logging.getLogger(__name__)
DATASETS_PREFIX = "localfs_datasets:" DATASETS_PREFIX = "localfs_datasets:"
@@ -141,6 +145,13 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
if page_token and not page_token.isnumeric(): if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token") raise ValueError("Invalid page_token")
if filter_condition is not None and filter_condition.strip():
dataset_type = get_provider_type(self.__module__)
provider_id = dataset_info.dataset_def.provider_id
log.warning(
f"Data filtering is not supported yet for {dataset_type}::{provider_id}, ignoring filter_condition: {filter_condition}"
)
if page_token is None or len(page_token) == 0: if page_token is None or len(page_token) == 0:
next_page_token = 0 next_page_token = 0
else: else:
@@ -172,7 +183,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df) new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url) url = str(dataset_info.dataset_def.url.uri)
parsed_url = urlparse(url) parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme: if parsed_url.scheme == "file" or not parsed_url.scheme:

View file

@@ -3,6 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import datasets as hf_datasets import datasets as hf_datasets
@@ -10,11 +11,14 @@ import datasets as hf_datasets
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.common.provider_utils import get_provider_type
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import HuggingfaceDatasetIOConfig from .config import HuggingfaceDatasetIOConfig
log = logging.getLogger(__name__)
DATASETS_PREFIX = "datasets:" DATASETS_PREFIX = "datasets:"
@@ -86,6 +90,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
if page_token and not page_token.isnumeric(): if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token") raise ValueError("Invalid page_token")
if filter_condition is not None and filter_condition.strip():
dataset_type = get_provider_type(self.__module__)
provider_id = dataset_def.provider_id
log.warning(
f"Data filtering is not supported yet for {dataset_type}::{provider_id}, ignoring filter_condition: {filter_condition}"
)
if page_token is None or len(page_token) == 0: if page_token is None or len(page_token) == 0:
next_page_token = 0 next_page_token = 0
else: else:

View file

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
def get_provider_type(module: str) -> str:
    """Return the provider type ("inline" or "remote") for a provider module path.

    Provider modules follow the convention
    ``llama_stack.providers.<type>.<...>``, where the third dotted component
    is the provider type. Callers typically pass ``self.__module__``.

    Args:
        module: Fully qualified dotted module name.

    Returns:
        Either ``"inline"`` or ``"remote"``.

    Raises:
        ValueError: If *module* does not match the expected
            ``llama_stack.providers.<inline|remote>...`` layout.
    """
    parts = module.split(".")
    # Check the length explicitly so short names (e.g. "llama_stack") raise
    # ValueError as documented, instead of leaking an IndexError.
    if len(parts) < 3 or parts[0] != "llama_stack" or parts[1] != "providers":
        raise ValueError(f"Invalid module name <{module}>")
    if parts[2] not in ("inline", "remote"):
        raise ValueError(f"Invalid module name <{module}>")
    return parts[2]