From beb6d11cd84924f9263f98b3611a30b9a3d8ff4b Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Tue, 8 Jul 2025 22:30:26 +0200 Subject: [PATCH] chore(api): add mypy coverage to localfs Signed-off-by: Mustafa Elbehery --- .../inline/datasetio/localfs/datasetio.py | 38 +++++++++++++++---- pyproject.toml | 2 +- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index da71ecb17..079806440 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -12,7 +12,7 @@ from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Dataset from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri -from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records from .config import LocalFSDatasetIOConfig @@ -27,11 +27,13 @@ class PandasDataframeDataset: self.df = None def __len__(self) -> int: - assert self.df is not None, "Dataset not loaded. Please call .load() first" + if self.df is None: + raise RuntimeError("Dataset not loaded. Please call .load() first") return len(self.df) def __getitem__(self, idx): - assert self.df is not None, "Dataset not loaded. Please call .load() first" + if self.df is None: + raise RuntimeError("Dataset not loaded. Please call .load() first") if isinstance(idx, slice): return self.df.iloc[idx].to_dict(orient="records") else: @@ -49,15 +51,27 @@ class PandasDataframeDataset: raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}") if self.df is None: - raise ValueError(f"Failed to load dataset from {self.dataset_def.url}") + if self.dataset_def.source.type == "uri": + raise ValueError(f"Failed to load dataset from {self.dataset_def.source.uri}") + else: + raise ValueError("Failed to load dataset from rows") + + +class DatasetStore: + def __init__(self, dataset_infos: dict[str, Dataset]) -> None: + self.dataset_infos = dataset_infos + + def get_dataset(self, dataset_id: str) -> Dataset: + return self.dataset_infos[dataset_id] class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): def __init__(self, config: LocalFSDatasetIOConfig) -> None: self.config = config # local registry for keeping track of datasets within the provider - self.dataset_infos = {} - self.kvstore = None + self.dataset_infos: dict[str, Dataset] = {} + self.kvstore: KVStore | None = None + self.dataset_store = DatasetStore(self.dataset_infos) async def initialize(self) -> None: self.kvstore = await kvstore_impl(self.config.kvstore) @@ -66,8 +80,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): end_key = f"{DATASETS_PREFIX}\xff" stored_datasets = await self.kvstore.values_in_range(start_key, end_key) - for dataset in stored_datasets: - dataset = Dataset.model_validate_json(dataset) + for dataset_json in stored_datasets: + dataset = Dataset.model_validate_json(dataset_json) self.dataset_infos[dataset.identifier] = dataset async def shutdown(self) -> None: ... @@ -78,6 +92,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): ) -> None: # Store in kvstore key = f"{DATASETS_PREFIX}{dataset_def.identifier}" + if self.kvstore is None: + raise RuntimeError("KVStore not initialized. Please call initialize() first") await self.kvstore.set( key=key, value=dataset_def.model_dump_json(), @@ -86,6 +102,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): async def unregister_dataset(self, dataset_id: str) -> None: key = f"{DATASETS_PREFIX}{dataset_id}" + if self.kvstore is None: + raise RuntimeError("KVStore not initialized. Please call initialize() first") await self.kvstore.delete(key=key) del self.dataset_infos[dataset_id] @@ -99,6 +117,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): dataset_impl = PandasDataframeDataset(dataset_def) await dataset_impl.load() + if dataset_impl.df is None: + raise RuntimeError("Failed to load dataset dataframe") records = dataset_impl.df.to_dict("records") return paginate_records(records, start_index, limit) @@ -108,4 +128,6 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): await dataset_impl.load() new_rows_df = pandas.DataFrame(rows) + if dataset_impl.df is None: + raise RuntimeError("Failed to load dataset dataframe") dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) diff --git a/pyproject.toml b/pyproject.toml index 30598e5e3..6f78d9aa8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -245,7 +245,7 @@ exclude = [ "^llama_stack/providers/inline/agents/meta_reference/", "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$", - "^llama_stack/providers/inline/datasetio/localfs/", + "^llama_stack/providers/inline/agents/meta_reference/safety\\.py$", "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$", "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$", "^llama_stack/models/llama/llama3/generation\\.py$",