mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-11 20:40:40 +00:00
feat: Add warning message for unsupported param
The two datasetio providers, inline::localfs and remote::huggingface, do not support the filter_condition parameter that is defined for the get_rows_paginated API. This commit adds a warning message when a non-empty filter_condition is passed to this API for these providers. Signed-off-by: Josh Salomon <jsalomon@redhat.com>
This commit is contained in:
parent
99b6925ad8
commit
dc2995842d
3 changed files with 38 additions and 1 deletions
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
@ -16,11 +17,14 @@ from llama_stack.apis.common.content_types import URL
|
|||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.common.provider_utils import get_provider_type
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DATASETS_PREFIX = "localfs_datasets:"
|
||||
|
||||
|
||||
|
@ -141,6 +145,13 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
|
||||
if filter_condition is not None and filter_condition.strip():
|
||||
dataset_type = get_provider_type(self.__module__)
|
||||
provider_id = dataset_info.dataset_def.provider_id
|
||||
log.warning(
|
||||
f"Data filtering is not supported yet for {dataset_type}::{provider_id}, ignoring filter_condition: {filter_condition}"
|
||||
)
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
else:
|
||||
|
@ -172,7 +183,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||
|
||||
url = str(dataset_info.dataset_def.url)
|
||||
url = str(dataset_info.dataset_def.url.uri)
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import datasets as hf_datasets
|
||||
|
@ -10,11 +11,14 @@ import datasets as hf_datasets
|
|||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.common.provider_utils import get_provider_type
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .config import HuggingfaceDatasetIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DATASETS_PREFIX = "datasets:"
|
||||
|
||||
|
||||
|
@ -86,6 +90,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
|
||||
if filter_condition is not None and filter_condition.strip():
|
||||
dataset_type = get_provider_type(self.__module__)
|
||||
provider_id = dataset_def.provider_id
|
||||
log.warning(
|
||||
f"Data filtering is not supported yet for {dataset_type}::{provider_id}, ignoring filter_condition: {filter_condition}"
|
||||
)
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
else:
|
||||
|
|
15
llama_stack/providers/utils/common/provider_utils.py
Normal file
15
llama_stack/providers/utils/common/provider_utils.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
def get_provider_type(module: str) -> str:
    """Return the provider type encoded in a provider module's dotted path.

    The path must start with ``llama_stack.providers.`` and its third
    component must be either ``inline`` or ``remote``; that third component
    is returned.

    Args:
        module: Dotted module path, e.g. the ``__module__`` of a provider
            implementation class.

    Returns:
        ``"inline"`` or ``"remote"``.

    Raises:
        ValueError: If the path has fewer than three components, does not
            start with ``llama_stack.providers``, or names a provider type
            other than ``inline`` / ``remote``.
    """
    parts = module.split(".")
    # Check the length before indexing so that a short path is reported as an
    # invalid module name (ValueError) rather than leaking an IndexError.
    if (
        len(parts) < 3
        or parts[0] != "llama_stack"
        or parts[1] != "providers"
        or parts[2] not in ("inline", "remote")
    ):
        raise ValueError(f"Invalid module name <{module}>")
    return parts[2]
|
Loading…
Add table
Add a link
Reference in a new issue