From f046899a1cf4b35c1f1f4092196b98437cd3e2b2 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 14 Oct 2024 13:16:39 -0700 Subject: [PATCH] datasets api --- .../apis/{dataset => datasets}/__init__.py | 2 +- llama_stack/apis/datasets/client.py | 92 +++++++++++++++++++ .../dataset.py => datasets/datasets.py} | 10 +- llama_stack/apis/evals/evals.py | 2 +- llama_stack/distribution/datatypes.py | 10 ++ llama_stack/distribution/distribution.py | 20 +++- llama_stack/distribution/registry/__init__.py | 17 ++++ .../registry/datasets/__init__.py | 4 +- .../distribution/registry/datasets/dataset.py | 90 ++++++------------ .../registry/datasets/dataset_wrappers.py | 78 ++++++++++++++++ llama_stack/distribution/resolver.py | 23 +++++ llama_stack/providers/datatypes.py | 4 +- .../impls/meta_reference/evals/evals.py | 6 +- .../evals/scorer/basic_scorers.py | 2 +- tests/examples/local-run.yaml | 1 + 15 files changed, 281 insertions(+), 80 deletions(-) rename llama_stack/apis/{dataset => datasets}/__init__.py (82%) create mode 100644 llama_stack/apis/datasets/client.py rename llama_stack/apis/{dataset/dataset.py => datasets/datasets.py} (96%) create mode 100644 llama_stack/distribution/registry/datasets/dataset_wrappers.py diff --git a/llama_stack/apis/dataset/__init__.py b/llama_stack/apis/datasets/__init__.py similarity index 82% rename from llama_stack/apis/dataset/__init__.py rename to llama_stack/apis/datasets/__init__.py index 33557a0ab..102b9927f 100644 --- a/llama_stack/apis/dataset/__init__.py +++ b/llama_stack/apis/datasets/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .dataset import * # noqa: F401 F403 +from .datasets import * # noqa: F401 F403 diff --git a/llama_stack/apis/datasets/client.py b/llama_stack/apis/datasets/client.py new file mode 100644 index 000000000..241db6568 --- /dev/null +++ b/llama_stack/apis/datasets/client.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import json + +import fire +import httpx + +from .datasets import * # noqa: F403 + + +class DatasetClient(Datasets): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def create_dataset( + self, + dataset_def: DatasetDef, + ) -> None: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/datasets/create", + json={ + "dataset_def": json.loads(dataset_def.json()), + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + return None + + async def get_dataset( + self, + dataset_identifier: str, + ) -> DatasetDef: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/datasets/create", + json={ + "dataset_identifier": dataset_identifier, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + return DatasetDef(**response.json()) + + async def delete_dataset( + self, + dataset_identifier: str, + ) -> DatasetDef: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/datasets/delete", + json={ + "dataset_identifier": dataset_identifier, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + return None + + +async def run_main(host: str, port: int): + client = DatasetClient(f"http://{host}:{port}") + + # Custom Eval Task + response = await client.create_dataset( + dataset_def=CustomDatasetDef( + identifier="test-dataset", + url="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", + ), + ) + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/dataset/dataset.py b/llama_stack/apis/datasets/datasets.py similarity index 96% rename from llama_stack/apis/dataset/dataset.py rename to llama_stack/apis/datasets/datasets.py index 798f3aba9..c79301557 100644 --- a/llama_stack/apis/dataset/dataset.py +++ b/llama_stack/apis/datasets/datasets.py @@ -143,19 +143,19 @@ class BaseDataset(ABC, Generic[TDatasetSample]): class Datasets(Protocol): @webmethod(route="/datasets/create") - def create_dataset( + async def create_dataset( self, - dataset: DatasetDef, + dataset_def: DatasetDef, ) -> None: ... @webmethod(route="/datasets/get") - def get_dataset( + async def get_dataset( self, dataset_identifier: str, ) -> DatasetDef: ... @webmethod(route="/datasets/delete") - def delete_dataset( + async def delete_dataset( self, - dataset_uuid: str, + dataset_identifier: str, ) -> None: ... diff --git a/llama_stack/apis/evals/evals.py b/llama_stack/apis/evals/evals.py index a62fa4418..af0b291e8 100644 --- a/llama_stack/apis/evals/evals.py +++ b/llama_stack/apis/evals/evals.py @@ -11,7 +11,7 @@ from llama_models.schema_utils import webmethod from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.dataset import * # noqa: F403 +from llama_stack.apis.datasets import * # noqa: F403 class EvaluationJob(BaseModel): diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 0044de09e..ce7f5a8e5 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -73,6 +73,16 @@ class RoutingTableProviderSpec(ProviderSpec): pip_packages: List[str] = Field(default_factory=list) +# Example: /datasets +class RegistryProviderSpec(ProviderSpec): + provider_type: str = "registry" + config_class: str = "" + docker_image: Optional[str] = None + + module: str + pip_packages: List[str] = Field(default_factory=list) + + class DistributionSpec(BaseModel): description: Optional[str] = Field( default="", diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 999646cc0..d96db23b4 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -21,6 +21,19 @@ class AutoRoutedApiInfo(BaseModel): router_api: Api +class RegistryApiInfo(BaseModel): + registry_api: Api + # registry: Registry + + +def builtin_registry_apis() -> List[RegistryApiInfo]: + return [ + RegistryApiInfo( + registry_api=Api.datasets, + ) + ] + + def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: return [ AutoRoutedApiInfo( @@ -42,7 +55,12 @@ def providable_apis() -> List[Api]: routing_table_apis = set( x.routing_table_api for x in builtin_automatically_routed_apis() ) - return [api for api in Api if api not in routing_table_apis and api != Api.inspect] + registry_apis = set( + x.registry_api for x in builtin_registry_apis() if x.registry_api + ) + non_providable_apis = routing_table_apis | registry_apis | {Api.inspect} + + return [api for api in Api if api not in non_providable_apis] def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]: diff --git a/llama_stack/distribution/registry/__init__.py b/llama_stack/distribution/registry/__init__.py index 756f351d8..6e6833328 100644 --- a/llama_stack/distribution/registry/__init__.py +++ b/llama_stack/distribution/registry/__init__.py @@ -3,3 +3,20 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any + +from llama_stack.providers.datatypes import Api +from .datasets.dataset import DatasetRegistryImpl + + +async def get_registry_impl(api: Api, _deps) -> Any: + api_to_registry = { + "datasets": DatasetRegistryImpl, + } + + if api.value not in api_to_registry: + raise ValueError(f"API {api.value} not found in registry map") + + impl = api_to_registry[api.value]() + await impl.initialize() + return impl diff --git a/llama_stack/distribution/registry/datasets/__init__.py b/llama_stack/distribution/registry/datasets/__init__.py index 68de3fa87..384028b9e 100644 --- a/llama_stack/distribution/registry/datasets/__init__.py +++ b/llama_stack/distribution/registry/datasets/__init__.py @@ -5,9 +5,9 @@ # the root directory of this source tree. # TODO: make these import config based -from llama_stack.apis.dataset import * # noqa: F403 +from llama_stack.apis.datasets import * # noqa: F403 from ..registry import Registry -from .dataset import CustomDataset, HuggingfaceDataset +from .dataset_wrappers import CustomDataset, HuggingfaceDataset class DatasetRegistry(Registry[BaseDataset]): diff --git a/llama_stack/distribution/registry/datasets/dataset.py b/llama_stack/distribution/registry/datasets/dataset.py index 0bd86b8d4..936fd0713 100644 --- a/llama_stack/distribution/registry/datasets/dataset.py +++ b/llama_stack/distribution/registry/datasets/dataset.py @@ -3,76 +3,38 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pandas -from datasets import Dataset, load_dataset -from llama_stack.apis.dataset import * # noqa: F403 +# from llama_stack.apis.datasets import * +# from llama_stack.distribution.registry.datasets import DatasetRegistry # noqa: F403 +# from ..registry import Registry +# from .dataset_wrappers import CustomDataset, HuggingfaceDataset -class CustomDataset(BaseDataset[DictSample]): - def __init__(self, config: CustomDatasetDef) -> None: - super().__init__() - self.config = config - self.dataset = None - self.index = 0 +class DatasetRegistryImpl(Datasets): + """API Impl to interact with underlying dataset registry""" - @property - def dataset_id(self) -> str: - return self.config.identifier + def __init__( + self, + ) -> None: + pass - def __iter__(self) -> Iterator[DictSample]: - if not self.dataset: - self.load() - return (DictSample(data=x) for x in self.dataset) + async def initialize(self) -> None: + pass - def __str__(self) -> str: - return f"CustomDataset({self.config})" + async def shutdown(self) -> None: + pass - def __len__(self) -> int: - if not self.dataset: - self.load() - return len(self.dataset) + async def create_dataset( + self, + dataset_def: DatasetDef, + ) -> None: + print(f"Creating dataset {dataset.identifier}") - def load(self, n_samples: Optional[int] = None) -> None: - if self.dataset: - return + async def get_dataset( + self, + dataset_identifier: str, + ) -> DatasetDef: + pass - # TODO: better support w/ data url - if self.config.url.endswith(".csv"): - df = pandas.read_csv(self.config.url) - elif self.config.url.endswith(".xlsx"): - df = pandas.read_excel(self.config.url) - - if n_samples is not None: - df = df.sample(n=n_samples) - - self.dataset = Dataset.from_pandas(df) - - -class HuggingfaceDataset(BaseDataset[DictSample]): - def __init__(self, config: HuggingfaceDatasetDef): - super().__init__() - self.config = config - self.dataset = None - - @property - def dataset_id(self) -> str: - return self.config.identifier - - def __iter__(self) -> Iterator[DictSample]: - if not self.dataset: - self.load() - return (DictSample(data=x) for x in self.dataset) - - def __str__(self): - return f"HuggingfaceDataset({self.config})" - - def __len__(self): - if not self.dataset: - self.load() - return len(self.dataset) - - def load(self): - if self.dataset: - return - self.dataset = load_dataset(self.config.dataset_name, **self.config.kwargs) + async def delete_dataset(self, dataset_identifier: str) -> None: + pass diff --git a/llama_stack/distribution/registry/datasets/dataset_wrappers.py b/llama_stack/distribution/registry/datasets/dataset_wrappers.py new file mode 100644 index 000000000..e18165a11 --- /dev/null +++ b/llama_stack/distribution/registry/datasets/dataset_wrappers.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import pandas +from datasets import Dataset, load_dataset + +from llama_stack.apis.datasets import * # noqa: F403 + + +class CustomDataset(BaseDataset[DictSample]): + def __init__(self, config: CustomDatasetDef) -> None: + super().__init__() + self.config = config + self.dataset = None + self.index = 0 + + @property + def dataset_id(self) -> str: + return self.config.identifier + + def __iter__(self) -> Iterator[DictSample]: + if not self.dataset: + self.load() + return (DictSample(data=x) for x in self.dataset) + + def __str__(self) -> str: + return f"CustomDataset({self.config})" + + def __len__(self) -> int: + if not self.dataset: + self.load() + return len(self.dataset) + + def load(self, n_samples: Optional[int] = None) -> None: + if self.dataset: + return + + # TODO: better support w/ data url + if self.config.url.endswith(".csv"): + df = pandas.read_csv(self.config.url) + elif self.config.url.endswith(".xlsx"): + df = pandas.read_excel(self.config.url) + + if n_samples is not None: + df = df.sample(n=n_samples) + + self.dataset = Dataset.from_pandas(df) + + +class HuggingfaceDataset(BaseDataset[DictSample]): + def __init__(self, config: HuggingfaceDatasetDef): + super().__init__() + self.config = config + self.dataset = None + + @property + def dataset_id(self) -> str: + return self.config.identifier + + def __iter__(self) -> Iterator[DictSample]: + if not self.dataset: + self.load() + return (DictSample(data=x) for x in self.dataset) + + def __str__(self): + return f"HuggingfaceDataset({self.config})" + + def __len__(self): + if not self.dataset: + self.load() + return len(self.dataset) + + def load(self): + if self.dataset: + return + self.dataset = load_dataset(self.config.dataset_name, **self.config.kwargs) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 672a4ea60..e71c3fd8c 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -12,6 +12,7 @@ from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.apis.agents import Agents +from llama_stack.apis.datasets import Datasets from llama_stack.apis.evals import Evals from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect @@ -23,6 +24,7 @@ from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry from llama_stack.distribution.distribution import ( builtin_automatically_routed_apis, + builtin_registry_apis, get_provider_registry, ) from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -40,6 +42,7 @@ def api_protocol_map() -> Dict[Api, Any]: Api.shields: Shields, Api.telemetry: Telemetry, Api.evals: Evals, + Api.datasets: Datasets, } @@ -139,6 +142,20 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An ) } + for info in builtin_registry_apis(): + providers_with_specs[info.registry_api.value] = { + "__builtin__": ProviderWithSpec( + provider_id="__registry__", + provider_type="__registry__", + config={}, + spec=RegistryProviderSpec( + api=info.registry_api, + module="llama_stack.distribution.registry", + deps__=[], + ), + ) + } + sorted_providers = topological_sort( {k: v.values() for k, v in providers_with_specs.items()} ) @@ -259,6 +276,12 @@ async def instantiate_provider( config = None args = [provider_spec.api, inner_impls, deps] + elif isinstance(provider_spec, RegistryProviderSpec): + print("ROUTER PROVIDER SPEC") + method = "get_registry_impl" + + config = None + args = [provider_spec.api, deps] else: method = "get_provider_impl" diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 50ab0691b..1d397c9e7 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -28,11 +28,13 @@ class Api(Enum): models = "models" shields = "shields" memory_banks = "memory_banks" - evals = "evals" # built-in API inspect = "inspect" + evals = "evals" + datasets = "datasets" + class ModelsProtocolPrivate(Protocol): async def list_models(self) -> List[ModelDef]: ... diff --git a/llama_stack/providers/impls/meta_reference/evals/evals.py b/llama_stack/providers/impls/meta_reference/evals/evals.py index f717fc9d8..3ae988cbd 100644 --- a/llama_stack/providers/impls/meta_reference/evals/evals.py +++ b/llama_stack/providers/impls/meta_reference/evals/evals.py @@ -9,11 +9,9 @@ from termcolor import cprint from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.evals import * # noqa: F403 -from llama_stack.apis.dataset import * # noqa: F403 +from llama_stack.apis.datasets import * # noqa: F403 from .config import MetaReferenceEvalsImplConfig - -# from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry from .tasks.run_eval_task import RunEvalTask @@ -47,7 +45,7 @@ class MetaReferenceEvalsImpl(Evals): eval_task_config = EvaluateTaskConfig( dataset_config=EvaluateDatasetConfig( dataset_name=dataset, - row_limit=2, + row_limit=3, ), generation_config=EvaluateModelGenerationConfig( model=model, diff --git a/llama_stack/providers/impls/meta_reference/evals/scorer/basic_scorers.py b/llama_stack/providers/impls/meta_reference/evals/scorer/basic_scorers.py index ff9639ecd..47d41c6d6 100644 --- a/llama_stack/providers/impls/meta_reference/evals/scorer/basic_scorers.py +++ b/llama_stack/providers/impls/meta_reference/evals/scorer/basic_scorers.py @@ -6,7 +6,7 @@ import random from llama_stack.apis.evals.evals import BaseScorer, EvalResult, SingleEvalResult -from llama_stack.apis.dataset.dataset import * # noqa: F401 F403 +from llama_stack.apis.datasets.datasets import * # noqa: F401 F403 class AggregateScorer(BaseScorer[ScorerInputSample]): diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml index 3c9f73e0b..31fb72670 100644 --- a/tests/examples/local-run.yaml +++ b/tests/examples/local-run.yaml @@ -12,6 +12,7 @@ apis: - inference - safety - evals +- datasets providers: evals: - provider_id: meta-reference