mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-13 04:22:35 +00:00
dataset datasetio
This commit is contained in:
parent
e8de70fdbe
commit
f8d9e4f60f
8 changed files with 249 additions and 10 deletions
|
|
@ -5,17 +5,83 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from dataclasses import dataclass
|
||||
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.dataset_utils import BaseDataset
|
||||
|
||||
from .config import MetaReferenceDatasetIOConfig
|
||||
|
||||
|
||||
@dataclass
class DatasetInfo:
    """Registry record pairing a dataset definition with its loaded data."""

    # Definition of the dataset this record describes.
    dataset_def: DatasetDef
    # Loaded dataset object that rows are read from.
    dataset_impl: BaseDataset
    # Row offset at which the next page starts; None until a value is assigned.
    next_page_token: Optional[int] = None
|
||||
|
||||
|
||||
class CustomDataset(BaseDataset):
    """Dataset backed by a pandas DataFrame loaded from ``dataset_def.url.uri``.

    Supported sources are URIs ending in ``.csv`` or ``.xlsx`` and ``data:``
    URLs. The data is loaded eagerly when the object is constructed.
    """

    def __init__(self, dataset_def: DatasetDef, *args, **kwargs) -> None:
        """Store the definition and load its data immediately.

        Extra positional and keyword arguments are forwarded to the base class.
        """
        super().__init__(*args, **kwargs)
        self.dataset_def = dataset_def
        # TODO: validate dataset_def against schema
        self.df = None
        self.load()

    def __len__(self) -> int:
        """Return the number of rows, loading the data first if needed."""
        if self.df is None:
            self.load()
        return len(self.df)

    def __getitem__(self, idx):
        """Return one row as a dict, or a list of row dicts for a slice."""
        # Mirror __len__: make sure the frame exists before indexing into it.
        if self.df is None:
            self.load()
        if isinstance(idx, slice):
            return self.df.iloc[idx].to_dict(orient="records")
        return self.df.iloc[idx].to_dict()

    def load(self) -> None:
        """Populate ``self.df`` from the dataset URL; no-op if already loaded.

        Raises:
            ValueError: if the URI is not a ``.csv``/``.xlsx`` path or a
                ``data:`` URL.
        """
        # A DataFrame has no unambiguous truth value (``if df:`` raises
        # ValueError), so the "already loaded" check must compare against None.
        if self.df is not None:
            return

        uri = self.dataset_def.url.uri
        # TODO: more robust support w/ data url
        if uri.endswith(".csv"):
            frame = pandas.read_csv(uri)
        elif uri.endswith(".xlsx"):
            frame = pandas.read_excel(uri)
        elif uri.startswith("data:"):
            frame = self._load_from_data_url(uri)
        else:
            raise ValueError(f"Unsupported file type: {self.dataset_def.url}")

        self.df = frame

    def _load_from_data_url(self, uri: str):
        """Decode the payload of a ``data:`` URL and parse it into a DataFrame."""
        # Imported locally so this block does not depend on the module's
        # import section providing these standard-library names.
        import base64
        import io
        from urllib.parse import unquote

        # NOTE(review): parse_data_url is not among this module's explicit
        # imports -- confirm it is in scope (e.g. through a star import).
        parts = parse_data_url(uri)
        payload = parts["data"]
        if parts["is_base64"]:
            payload = base64.b64decode(payload)
        else:
            # Percent-encoded text: decode it, then turn it into bytes using
            # the declared charset, falling back to UTF-8.
            payload = unquote(payload).encode(parts["encoding"] or "utf-8")

        buffer = io.BytesIO(payload)
        # Any text/* MIME type is parsed as CSV; everything else as Excel.
        if parts["mimetype"].split("/")[0] == "text":
            return pandas.read_csv(buffer)
        return pandas.read_excel(buffer)
|
||||
|
||||
|
||||
class MetaReferenceDatasetioImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
def __init__(self, config: MetaReferenceDatasetIOConfig) -> None:
    """Keep the provider config and start with an empty dataset registry."""
    # Provider-local registry of the datasets this instance tracks.
    self.dataset_infos = {}
    self.config = config
|
||||
|
||||
# Stub: the body is empty, so calling this performs no work.
async def initialize(self) -> None: ...
|
||||
|
||||
|
|
@ -23,21 +89,38 @@ class MetaReferenceDatasetioImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
|
||||
async def register_dataset(
    self,
    dataset_def: DatasetDef,
) -> None:
    """Load the dataset and record it in the provider-local registry.

    The entry is keyed by ``dataset_def.identifier``; registering the same
    identifier again replaces the previous entry. Building ``CustomDataset``
    loads the data immediately, so any load error surfaces here.
    """
    self.dataset_infos[dataset_def.identifier] = DatasetInfo(
        dataset_def=dataset_def,
        dataset_impl=CustomDataset(dataset_def),
        # Pagination for a freshly registered dataset starts at the first row.
        next_page_token=0,
    )
|
||||
|
||||
async def list_datasets(self) -> List[DatasetDef]:
    """Return the definitions of all datasets held in the local registry."""
    return [info.dataset_def for info in self.dataset_infos.values()]
|
||||
|
||||
async def get_rows_paginated(
    self,
    dataset_id: str,
    rows_in_page: int,
    page_token: Optional[int] = None,
    filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
    """Return the next page of rows for a registered dataset.

    Args:
        dataset_id: identifier the dataset was registered under.
        rows_in_page: maximum number of rows to return; ``-1`` returns every
            remaining row.
        page_token: ``None`` restarts paging from the first row. Any other
            value is not used as an offset; paging continues from the
            position stored for the dataset.
        filter_condition: accepted but currently not applied.

    Returns:
        A result holding the rows of this page, their count, and the row
        offset at which the following page starts.

    Raises:
        ValueError: if ``dataset_id`` is not registered or ``rows_in_page``
            is below ``-1``.
    """
    dataset_info = self.dataset_infos.get(dataset_id)
    if dataset_info is None:
        raise ValueError(f"Dataset {dataset_id} is not registered")
    if rows_in_page < -1:
        # A negative page size would move the stored cursor backwards.
        raise ValueError(f"Invalid rows_in_page: {rows_in_page}")

    if page_token is None:
        dataset_info.next_page_token = 0

    start = dataset_info.next_page_token
    available = len(dataset_info.dataset_impl)
    # -1 means "everything that is left"; otherwise clamp to the dataset end.
    end = available if rows_in_page == -1 else min(start + rows_in_page, available)

    rows = dataset_info.dataset_impl[start:end]
    dataset_info.next_page_token = end

    return PaginatedRowsResult(
        rows=rows,
        total_count=len(rows),
        next_page_token=dataset_info.next_page_token,
    )
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue