From d75095033d41a7d79ecd831702a27ad343809558 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 7 Nov 2024 10:21:25 -0800 Subject: [PATCH] huggingface provider --- .../inline/huggingface/datasetio/__init__.py | 18 ++++ .../inline/huggingface/datasetio/config.py | 9 ++ .../huggingface/datasetio/huggingface.py | 82 +++++++++++++++++++ llama_stack/providers/registry/datasetio.py | 8 ++ 4 files changed, 117 insertions(+) create mode 100644 llama_stack/providers/inline/huggingface/datasetio/__init__.py create mode 100644 llama_stack/providers/inline/huggingface/datasetio/config.py create mode 100644 llama_stack/providers/inline/huggingface/datasetio/huggingface.py diff --git a/llama_stack/providers/inline/huggingface/datasetio/__init__.py b/llama_stack/providers/inline/huggingface/datasetio/__init__.py new file mode 100644 index 000000000..cda67177c --- /dev/null +++ b/llama_stack/providers/inline/huggingface/datasetio/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import HuggingfaceDatasetIOConfig + + +async def get_provider_impl( + config: HuggingfaceDatasetIOConfig, + _deps, +): + from .huggingface import HuggingfaceDatasetIOImpl + + impl = HuggingfaceDatasetIOImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/huggingface/datasetio/config.py b/llama_stack/providers/inline/huggingface/datasetio/config.py new file mode 100644 index 000000000..89dbe53a0 --- /dev/null +++ b/llama_stack/providers/inline/huggingface/datasetio/config.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_stack.apis.datasetio import * # noqa: F401, F403 + + +class HuggingfaceDatasetIOConfig(BaseModel): ... diff --git a/llama_stack/providers/inline/huggingface/datasetio/huggingface.py b/llama_stack/providers/inline/huggingface/datasetio/huggingface.py new file mode 100644 index 000000000..e17850841 --- /dev/null +++ b/llama_stack/providers/inline/huggingface/datasetio/huggingface.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import List, Optional + +from llama_models.llama3.api.datatypes import * # noqa: F403 + +from llama_stack.apis.datasetio import * # noqa: F403 + +from datasets import load_dataset + +from llama_stack.apis.common.type_system import StringType +from llama_stack.providers.datatypes import DatasetsProtocolPrivate + +from .config import HuggingfaceDatasetIOConfig + + +class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): + def __init__(self, config: HuggingfaceDatasetIOConfig) -> None: + self.config = config + # local registry for keeping track of datasets within the provider + self.dataset_infos = {} + + async def initialize(self) -> None: + # pre-registered benchmark datasets + self.dataset_infos = { + "mmlu": DatasetDef( + identifier="mmlu", + url=URL(uri="https://huggingface.co/datasets/yanxi0830/ls-mmlu"), + dataset_schema={ + "expected_answer": StringType(), + "input_query": StringType(), + "generated_answer": StringType(), + }, + metadata={"path": "yanxi0830/ls-mmlu", "split": "train"}, + ) + } + + async def shutdown(self) -> None: ... + + async def register_dataset( + self, + dataset_def: DatasetDef, + ) -> None: + self.dataset_infos[dataset_def.identifier] = dataset_def + + async def list_datasets(self) -> List[DatasetDef]: + return list(self.dataset_infos.values()) + + async def get_rows_paginated( + self, + dataset_id: str, + rows_in_page: int, + page_token: Optional[str] = None, + filter_condition: Optional[str] = None, + ) -> PaginatedRowsResult: + dataset_def = self.dataset_infos[dataset_id] + loaded_dataset = load_dataset(**dataset_def.metadata) + + if page_token and not page_token.isnumeric(): + raise ValueError("Invalid page_token") + + if page_token is None or len(page_token) == 0: + next_page_token = 0 + else: + next_page_token = int(page_token) + + start = next_page_token + if rows_in_page == -1: + end = len(loaded_dataset) + else: + end = min(start + rows_in_page, len(loaded_dataset)) + + rows = [loaded_dataset[i] for i in range(start, end)] + + return PaginatedRowsResult( + rows=rows, + total_count=len(rows), + next_page_token=str(end), + ) diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index 976bbd448..f2c740629 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -19,4 +19,12 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.meta_reference.datasetio.MetaReferenceDatasetIOConfig", api_dependencies=[], ), + InlineProviderSpec( + api=Api.datasetio, + provider_type="huggingface", + pip_packages=["datasets"], + module="llama_stack.providers.inline.huggingface.datasetio", + config_class="llama_stack.providers.inline.huggingface.datasetio.HuggingfaceDatasetIOConfig", + api_dependencies=[], + ), ]