mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
feat: Add warning message for unsupported param
The 2 datasetio providers inline::localfs and remote::huggingface do not support the filter_condition parameter that is defined for the get_rows_paginated API. This commit adds a warning message when a non-empty filter_condition is passed to this API for these providers. Signed-off-by: Josh Salomon <jsalomon@redhat.com>
This commit is contained in:
parent
99b6925ad8
commit
dc2995842d
3 changed files with 38 additions and 1 deletions
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import base64
|
import base64
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
@ -16,11 +17,14 @@ from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||||
from llama_stack.apis.datasets import Dataset
|
from llama_stack.apis.datasets import Dataset
|
||||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.common.provider_utils import get_provider_type
|
||||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
|
||||||
from .config import LocalFSDatasetIOConfig
|
from .config import LocalFSDatasetIOConfig
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
DATASETS_PREFIX = "localfs_datasets:"
|
DATASETS_PREFIX = "localfs_datasets:"
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,6 +145,13 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
if page_token and not page_token.isnumeric():
|
if page_token and not page_token.isnumeric():
|
||||||
raise ValueError("Invalid page_token")
|
raise ValueError("Invalid page_token")
|
||||||
|
|
||||||
|
if filter_condition is not None and filter_condition.strip():
|
||||||
|
dataset_type = get_provider_type(self.__module__)
|
||||||
|
provider_id = dataset_info.dataset_def.provider_id
|
||||||
|
log.warning(
|
||||||
|
f"Data filtering is not supported yet for {dataset_type}::{provider_id}, ignoring filter_condition: {filter_condition}"
|
||||||
|
)
|
||||||
|
|
||||||
if page_token is None or len(page_token) == 0:
|
if page_token is None or len(page_token) == 0:
|
||||||
next_page_token = 0
|
next_page_token = 0
|
||||||
else:
|
else:
|
||||||
|
@ -172,7 +183,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
||||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||||
|
|
||||||
url = str(dataset_info.dataset_def.url)
|
url = str(dataset_info.dataset_def.url.uri)
|
||||||
parsed_url = urlparse(url)
|
parsed_url = urlparse(url)
|
||||||
|
|
||||||
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import datasets as hf_datasets
|
import datasets as hf_datasets
|
||||||
|
@ -10,11 +11,14 @@ import datasets as hf_datasets
|
||||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||||
from llama_stack.apis.datasets import Dataset
|
from llama_stack.apis.datasets import Dataset
|
||||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.common.provider_utils import get_provider_type
|
||||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
|
||||||
from .config import HuggingfaceDatasetIOConfig
|
from .config import HuggingfaceDatasetIOConfig
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
DATASETS_PREFIX = "datasets:"
|
DATASETS_PREFIX = "datasets:"
|
||||||
|
|
||||||
|
|
||||||
|
@ -86,6 +90,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
if page_token and not page_token.isnumeric():
|
if page_token and not page_token.isnumeric():
|
||||||
raise ValueError("Invalid page_token")
|
raise ValueError("Invalid page_token")
|
||||||
|
|
||||||
|
if filter_condition is not None and filter_condition.strip():
|
||||||
|
dataset_type = get_provider_type(self.__module__)
|
||||||
|
provider_id = dataset_def.provider_id
|
||||||
|
log.warning(
|
||||||
|
f"Data filtering is not supported yet for {dataset_type}::{provider_id}, ignoring filter_condition: {filter_condition}"
|
||||||
|
)
|
||||||
|
|
||||||
if page_token is None or len(page_token) == 0:
|
if page_token is None or len(page_token) == 0:
|
||||||
next_page_token = 0
|
next_page_token = 0
|
||||||
else:
|
else:
|
||||||
|
|
15
llama_stack/providers/utils/common/provider_utils.py
Normal file
15
llama_stack/providers/utils/common/provider_utils.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_type(module: str) -> str:
    """Return the provider type ("inline" or "remote") encoded in a module path.

    Provider implementation modules live under
    ``llama_stack.providers.<type>...``, where ``<type>`` is either
    ``inline`` or ``remote``; callers typically pass ``self.__module__``.

    :param module: Dotted module path of a provider implementation.
    :returns: ``"inline"`` or ``"remote"``.
    :raises ValueError: If *module* is not a valid llama_stack provider module.
    """
    parts = module.split(".")
    # Check the length before indexing so a malformed short path (e.g.
    # "llama_stack") raises the documented ValueError instead of IndexError,
    # which the original unchecked parts[0..2] indexing would have produced.
    if len(parts) < 3 or parts[0] != "llama_stack" or parts[1] != "providers":
        raise ValueError(f"Invalid module name <{module}>")
    if parts[2] in ("inline", "remote"):
        return parts[2]
    raise ValueError(f"Invalid module name <{module}>")
|
Loading…
Add table
Add a link
Reference in a new issue