diff --git a/llama_stack/apis/datasets/client.py b/llama_stack/apis/datasets/client.py index 241db6568..476a5964a 100644 --- a/llama_stack/apis/datasets/client.py +++ b/llama_stack/apis/datasets/client.py @@ -6,13 +6,26 @@ import asyncio import json +from typing import Optional import fire import httpx +from termcolor import cprint from .datasets import * # noqa: F403 +def deserialize_dataset_def(j: Optional[Dict[str, Any]]) -> Optional[DatasetDef]: + if not j: + return None + if j["type"] == "huggingface": + return HuggingfaceDatasetDef(**j) + elif j["type"] == "custom": + return CustomDatasetDef(**j) + else: + raise ValueError(f"Unknown dataset type: {j['type']}") + + class DatasetClient(Datasets): def __init__(self, base_url: str): self.base_url = base_url @@ -26,7 +39,7 @@ class DatasetClient(Datasets): async def create_dataset( self, dataset_def: DatasetDef, - ) -> None: + ) -> CreateDatasetResponse: async with httpx.AsyncClient() as client: response = await client.post( f"{self.base_url}/datasets/create", @@ -37,28 +50,31 @@ class DatasetClient(Datasets): timeout=60, ) response.raise_for_status() - return None + return CreateDatasetResponse(**response.json()) async def get_dataset( self, dataset_identifier: str, - ) -> DatasetDef: + ) -> Optional[DatasetDef]: async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.base_url}/datasets/create", - json={ + response = await client.get( + f"{self.base_url}/datasets/get", + params={ "dataset_identifier": dataset_identifier, }, headers={"Content-Type": "application/json"}, timeout=60, ) response.raise_for_status() - return DatasetDef(**response.json()) + if not response.json(): + return + + return deserialize_dataset_def(response.json()) async def delete_dataset( self, dataset_identifier: str, - ) -> DatasetDef: + ) -> DeleteDatasetResponse: async with httpx.AsyncClient() as client: response = await client.post( f"{self.base_url}/datasets/delete", @@ -69,19 +85,57 @@ class 
DatasetClient(Datasets): timeout=60, ) response.raise_for_status() - return None + return DeleteDatasetResponse(**response.json()) + + async def list_datasets( + self, + ) -> List[DatasetDef]: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/datasets/list", + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return [] + + return [deserialize_dataset_def(x) for x in response.json()] async def run_main(host: str, port: int): client = DatasetClient(f"http://{host}:{port}") - # Custom Eval Task + # register dataset response = await client.create_dataset( dataset_def=CustomDatasetDef( identifier="test-dataset", url="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", ), ) + cprint(response, "green") + + # get dataset + get_dataset = await client.get_dataset( + dataset_identifier="test-dataset", + ) + cprint(get_dataset, "cyan") + + # delete dataset + delete_dataset = await client.delete_dataset( + dataset_identifier="test-dataset", + ) + cprint(delete_dataset, "red") + + # get again after deletion + get_dataset = await client.get_dataset( + dataset_identifier="test-dataset", + ) + cprint(get_dataset, "yellow") + + # list datasets + list_dataset = await client.list_datasets() + cprint(list_dataset, "blue") def main(host: str, port: int): diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index c79301557..11a3f6096 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -115,6 +115,27 @@ DatasetDef = Annotated[ ] +class DatasetsResponseStatus(Enum): + success = "success" + fail = "fail" + + +@json_schema_type +class CreateDatasetResponse(BaseModel): + status: DatasetsResponseStatus = Field( + description="Return status of the dataset creation", + ) + msg: Optional[str] = None + + +@json_schema_type +class DeleteDatasetResponse(BaseModel): + status: 
DatasetsResponseStatus = Field( + description="Return status of the dataset deletion", + ) + msg: Optional[str] = None + + class BaseDataset(ABC, Generic[TDatasetSample]): def __init__(self) -> None: self.type: str = self.__class__.__name__ @@ -146,16 +167,19 @@ class Datasets(Protocol): async def create_dataset( self, dataset_def: DatasetDef, - ) -> None: ... + ) -> CreateDatasetResponse: ... - @webmethod(route="/datasets/get") + @webmethod(route="/datasets/get", method="GET") async def get_dataset( self, dataset_identifier: str, - ) -> DatasetDef: ... + ) -> Optional[DatasetDef]: ... @webmethod(route="/datasets/delete") async def delete_dataset( self, dataset_identifier: str, - ) -> None: ... + ) -> DeleteDatasetResponse: ... + + @webmethod(route="/datasets/list", method="GET") + async def list_datasets(self) -> List[DatasetDef]: ... diff --git a/llama_stack/distribution/registry/datasets/dataset.py b/llama_stack/distribution/registry/datasets/dataset.py index 936fd0713..838e8c65f 100644 --- a/llama_stack/distribution/registry/datasets/dataset.py +++ b/llama_stack/distribution/registry/datasets/dataset.py @@ -4,10 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-# from llama_stack.apis.datasets import * -# from llama_stack.distribution.registry.datasets import DatasetRegistry # noqa: F403 -# from ..registry import Registry -# from .dataset_wrappers import CustomDataset, HuggingfaceDataset +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.distribution.registry.datasets import DatasetRegistry +from llama_stack.distribution.registry.datasets.dataset_wrappers import ( + CustomDataset, + HuggingfaceDataset, +) class DatasetRegistryImpl(Datasets): @@ -27,14 +29,55 @@ class DatasetRegistryImpl(Datasets): async def create_dataset( self, dataset_def: DatasetDef, - ) -> None: - print(f"Creating dataset {dataset.identifier}") + ) -> CreateDatasetResponse: + if dataset_def.type == DatasetType.huggingface.value: + dataset_cls = HuggingfaceDataset(dataset_def) + else: + dataset_cls = CustomDataset(dataset_def) + + try: + DatasetRegistry.register( + dataset_def.identifier, + dataset_cls, + ) + except ValueError as e: + return CreateDatasetResponse( + status=DatasetsResponseStatus.fail, + msg=str(e), + ) + + return CreateDatasetResponse( + status=DatasetsResponseStatus.success, + msg=f"Dataset '{dataset_def.identifier}' registered", + ) async def get_dataset( self, dataset_identifier: str, - ) -> DatasetDef: - pass + ) -> Optional[DatasetDef]: + try: + dataset_ref = DatasetRegistry.get(dataset_identifier).config + except ValueError: + return None - async def delete_dataset(self, dataset_identifier: str) -> None: - pass + return dataset_ref + + async def delete_dataset(self, dataset_identifier: str) -> DeleteDatasetResponse: + try: + DatasetRegistry.delete(dataset_identifier) + except ValueError as e: + return DeleteDatasetResponse( + status=DatasetsResponseStatus.fail, + msg=str(e), + ) + + return DeleteDatasetResponse( + status=DatasetsResponseStatus.success, + msg=f"Dataset '{dataset_identifier}' deleted", + ) + + async def list_datasets(self) -> List[DatasetDef]: + return [ + 
DatasetRegistry.get(dataset_identifier).config + for dataset_identifier in DatasetRegistry.names() + ] diff --git a/llama_stack/distribution/registry/registry.py b/llama_stack/distribution/registry/registry.py index b4a5b626d..313fb6d4e 100644 --- a/llama_stack/distribution/registry/registry.py +++ b/llama_stack/distribution/registry/registry.py @@ -27,6 +27,12 @@ class Registry(Generic[TRegistry]): raise ValueError(f"Dataset {name} not found.") return Registry._REGISTRY[name] + @staticmethod + def delete(name: str) -> None: + if name not in Registry._REGISTRY: + raise ValueError(f"Dataset {name} not found.") + del Registry._REGISTRY[name] + @staticmethod def reset() -> None: Registry._REGISTRY = {}