wip huggingface register

This commit is contained in:
Xi Yan 2024-11-07 15:59:55 -08:00
parent d1633dc412
commit 37d87c585a
5 changed files with 82 additions and 16 deletions

View file

@ -7,13 +7,27 @@ from typing import List, Optional
from llama_stack.apis.datasetio import * # noqa: F403
from datasets import load_dataset
from datasets import Dataset, load_dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from .config import HuggingfaceDatasetIOConfig
from .dataset_defs.llamastack_mmlu import llamastack_mmlu
def load_hf_dataset(dataset_def: DatasetDef):
    """Load the dataset described by *dataset_def*.

    Two sources are supported, tried in this order:

    1. ``dataset_def.metadata`` is forwarded as keyword arguments to
       ``datasets.load_dataset`` when it names a dataset to load. The
       dataset is named by the ``path`` key (the name of ``load_dataset``'s
       required first parameter); ``dataset_path`` is accepted as an alias
       and is passed on as ``path``. If both keys are present, ``path``
       wins and ``dataset_path`` is dropped.
    2. Otherwise the data at ``dataset_def.url`` is read into a DataFrame
       with ``get_dataframe_from_url`` and wrapped with
       ``Dataset.from_pandas``.

    Args:
        dataset_def: Dataset definition providing ``metadata`` (a mapping,
            or ``None``) and ``url``.

    Returns:
        Whatever ``load_dataset`` returns for the metadata branch, or the
        ``Dataset`` built from the DataFrame for the URL branch.

    Raises:
        ValueError: If ``get_dataframe_from_url`` returns ``None``.
    """
    # Copy so the registered definition's metadata is never mutated here.
    hub_kwargs = dict(dataset_def.metadata or {})

    # `load_dataset` has no `dataset_path` parameter, so forwarding that key
    # verbatim cannot satisfy its required `path` argument. Normalise it.
    alias = hub_kwargs.pop("dataset_path", None)
    if alias and not hub_kwargs.get("path"):
        hub_kwargs["path"] = alias

    if hub_kwargs.get("path"):
        return load_dataset(**hub_kwargs)

    # NOTE(review): `get_dataframe_from_url` is not imported in the visible
    # part of this module -- confirm it is in scope before merging.
    df = get_dataframe_from_url(dataset_def.url)
    if df is None:
        raise ValueError(f"Failed to load dataset from {dataset_def.url}")

    return Dataset.from_pandas(df)
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
def __init__(self, config: HuggingfaceDatasetIOConfig) -> None:
self.config = config
@ -31,6 +45,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
self,
dataset_def: DatasetDef,
) -> None:
print("registering dataset", dataset_def)
self.dataset_infos[dataset_def.identifier] = dataset_def
async def list_datasets(self) -> List[DatasetDef]:
@ -44,7 +60,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_dataset(**dataset_def.metadata)
loaded_dataset = load_hf_dataset(dataset_def)
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")