feat: Add warning message for unsupported param

The 2 datasetio providers inline::localfs and remote::huggingface do not
support the filter_condition parameter that is defined for the
get_rows_paginated API. This commit adds a warning message when a
non-empty filter_condition is passed to this API for these providers.

Signed-off-by: Josh Salomon <jsalomon@redhat.com>
This commit is contained in:
Josh Salomon 2025-03-05 20:27:25 +02:00
parent 99b6925ad8
commit dc2995842d
3 changed files with 38 additions and 1 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
@ -16,11 +17,14 @@ from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.common.provider_utils import get_provider_type
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import LocalFSDatasetIOConfig
log = logging.getLogger(__name__)
DATASETS_PREFIX = "localfs_datasets:"
@ -141,6 +145,13 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
if filter_condition is not None and filter_condition.strip():
dataset_type = get_provider_type(self.__module__)
provider_id = dataset_info.dataset_def.provider_id
log.warning(
f"Data filtering is not supported yet for {dataset_type}::{provider_id}, ignoring filter_condition: {filter_condition}"
)
if page_token is None or len(page_token) == 0:
next_page_token = 0
else:
@ -172,7 +183,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url)
url = str(dataset_info.dataset_def.url.uri)
parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme:

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List, Optional
import datasets as hf_datasets
@ -10,11 +11,14 @@ import datasets as hf_datasets
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.common.provider_utils import get_provider_type
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import HuggingfaceDatasetIOConfig
log = logging.getLogger(__name__)
DATASETS_PREFIX = "datasets:"
@ -86,6 +90,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
if filter_condition is not None and filter_condition.strip():
dataset_type = get_provider_type(self.__module__)
provider_id = dataset_def.provider_id
log.warning(
f"Data filtering is not supported yet for {dataset_type}::{provider_id}, ignoring filter_condition: {filter_condition}"
)
if page_token is None or len(page_token) == 0:
next_page_token = 0
else:

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
def get_provider_type(module: str) -> str:
    """Return the provider type segment of a llama_stack provider module name.

    The name is expected to have the form
    ``llama_stack.providers.<type>[...]`` where ``<type>`` is ``inline`` or
    ``remote``.

    Args:
        module: Dotted module name, e.g. a provider object's ``__module__``.

    Returns:
        ``"inline"`` or ``"remote"``.

    Raises:
        ValueError: If the name does not start with ``llama_stack.providers.``
            followed by ``inline`` or ``remote``. This includes names with
            fewer than three dotted segments.
    """
    parts = module.split(".")
    # Check the length first so short names raise ValueError like every other
    # malformed name, rather than IndexError from indexing.
    if len(parts) < 3 or parts[:2] != ["llama_stack", "providers"]:
        raise ValueError(f"Invalid module name <{module}>")
    provider_type = parts[2]
    if provider_type not in ("inline", "remote"):
        raise ValueError(f"Invalid module name <{module}>")
    return provider_type