mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
chore(api): add mypy coverage to localfs
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
81109a0f72
commit
beb6d11cd8
2 changed files with 31 additions and 9 deletions
|
@ -12,7 +12,7 @@ from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Dataset
|
from llama_stack.apis.datasets import Dataset
|
||||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
|
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||||
from llama_stack.providers.utils.pagination import paginate_records
|
from llama_stack.providers.utils.pagination import paginate_records
|
||||||
|
|
||||||
from .config import LocalFSDatasetIOConfig
|
from .config import LocalFSDatasetIOConfig
|
||||||
|
@ -27,11 +27,13 @@ class PandasDataframeDataset:
|
||||||
self.df = None
|
self.df = None
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
assert self.df is not None, "Dataset not loaded. Please call .load() first"
|
if self.df is None:
|
||||||
|
raise RuntimeError("Dataset not loaded. Please call .load() first")
|
||||||
return len(self.df)
|
return len(self.df)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
assert self.df is not None, "Dataset not loaded. Please call .load() first"
|
if self.df is None:
|
||||||
|
raise RuntimeError("Dataset not loaded. Please call .load() first")
|
||||||
if isinstance(idx, slice):
|
if isinstance(idx, slice):
|
||||||
return self.df.iloc[idx].to_dict(orient="records")
|
return self.df.iloc[idx].to_dict(orient="records")
|
||||||
else:
|
else:
|
||||||
|
@ -49,15 +51,27 @@ class PandasDataframeDataset:
|
||||||
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
|
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
|
||||||
|
|
||||||
if self.df is None:
|
if self.df is None:
|
||||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
if self.dataset_def.source.type == "uri":
|
||||||
|
raise ValueError(f"Failed to load dataset from {self.dataset_def.source.uri}")
|
||||||
|
else:
|
||||||
|
raise ValueError("Failed to load dataset from rows")
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetStore:
|
||||||
|
def __init__(self, dataset_infos: dict[str, Dataset]) -> None:
|
||||||
|
self.dataset_infos = dataset_infos
|
||||||
|
|
||||||
|
def get_dataset(self, dataset_id: str) -> Dataset:
|
||||||
|
return self.dataset_infos[dataset_id]
|
||||||
|
|
||||||
|
|
||||||
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
def __init__(self, config: LocalFSDatasetIOConfig) -> None:
|
def __init__(self, config: LocalFSDatasetIOConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
# local registry for keeping track of datasets within the provider
|
# local registry for keeping track of datasets within the provider
|
||||||
self.dataset_infos = {}
|
self.dataset_infos: dict[str, Dataset] = {}
|
||||||
self.kvstore = None
|
self.kvstore: KVStore | None = None
|
||||||
|
self.dataset_store = DatasetStore(self.dataset_infos)
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||||
|
@ -66,8 +80,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
end_key = f"{DATASETS_PREFIX}\xff"
|
end_key = f"{DATASETS_PREFIX}\xff"
|
||||||
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
|
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
|
||||||
|
|
||||||
for dataset in stored_datasets:
|
for dataset_json in stored_datasets:
|
||||||
dataset = Dataset.model_validate_json(dataset)
|
dataset = Dataset.model_validate_json(dataset_json)
|
||||||
self.dataset_infos[dataset.identifier] = dataset
|
self.dataset_infos[dataset.identifier] = dataset
|
||||||
|
|
||||||
async def shutdown(self) -> None: ...
|
async def shutdown(self) -> None: ...
|
||||||
|
@ -78,6 +92,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
) -> None:
|
) -> None:
|
||||||
# Store in kvstore
|
# Store in kvstore
|
||||||
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
|
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
|
||||||
|
if self.kvstore is None:
|
||||||
|
raise RuntimeError("KVStore not initialized. Please call initialize() first")
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
key=key,
|
key=key,
|
||||||
value=dataset_def.model_dump_json(),
|
value=dataset_def.model_dump_json(),
|
||||||
|
@ -86,6 +102,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
|
|
||||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||||
key = f"{DATASETS_PREFIX}{dataset_id}"
|
key = f"{DATASETS_PREFIX}{dataset_id}"
|
||||||
|
if self.kvstore is None:
|
||||||
|
raise RuntimeError("KVStore not initialized. Please call initialize() first")
|
||||||
await self.kvstore.delete(key=key)
|
await self.kvstore.delete(key=key)
|
||||||
del self.dataset_infos[dataset_id]
|
del self.dataset_infos[dataset_id]
|
||||||
|
|
||||||
|
@ -99,6 +117,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||||
await dataset_impl.load()
|
await dataset_impl.load()
|
||||||
|
|
||||||
|
if dataset_impl.df is None:
|
||||||
|
raise RuntimeError("Failed to load dataset dataframe")
|
||||||
records = dataset_impl.df.to_dict("records")
|
records = dataset_impl.df.to_dict("records")
|
||||||
return paginate_records(records, start_index, limit)
|
return paginate_records(records, start_index, limit)
|
||||||
|
|
||||||
|
@ -108,4 +128,6 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
await dataset_impl.load()
|
await dataset_impl.load()
|
||||||
|
|
||||||
new_rows_df = pandas.DataFrame(rows)
|
new_rows_df = pandas.DataFrame(rows)
|
||||||
|
if dataset_impl.df is None:
|
||||||
|
raise RuntimeError("Failed to load dataset dataframe")
|
||||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||||
|
|
|
@ -245,7 +245,7 @@ exclude = [
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/",
|
"^llama_stack/providers/inline/agents/meta_reference/",
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
||||||
"^llama_stack/providers/inline/datasetio/localfs/",
|
"^llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
|
||||||
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
||||||
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
|
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/generation\\.py$",
|
"^llama_stack/models/llama/llama3/generation\\.py$",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue