From 9c501d042b0ca1f2a4bfd2848c1609af8bb46cb1 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 14 Oct 2024 14:19:15 -0700 Subject: [PATCH] cleanup hardcoded dataset registry --- llama_stack/apis/datasets/client.py | 14 +++++++++-- llama_stack/apis/evals/client.py | 23 +++++++++++++++---- .../registry/datasets/__init__.py | 23 ------------------- 3 files changed, 31 insertions(+), 29 deletions(-) diff --git a/llama_stack/apis/datasets/client.py b/llama_stack/apis/datasets/client.py index 476a5964a..e292b14d8 100644 --- a/llama_stack/apis/datasets/client.py +++ b/llama_stack/apis/datasets/client.py @@ -26,7 +26,7 @@ def deserialize_dataset_def(j: Optional[Dict[str, Any]]) -> Optional[DatasetDef] raise ValueError(f"Unknown dataset type: {j['type']}") -class DatasetClient(Datasets): +class DatasetsClient(Datasets): def __init__(self, base_url: str): self.base_url = base_url @@ -104,7 +104,7 @@ class DatasetClient(Datasets): async def run_main(host: str, port: int): - client = DatasetClient(f"http://{host}:{port}") + client = DatasetsClient(f"http://{host}:{port}") # register dataset response = await client.create_dataset( @@ -115,6 +115,16 @@ async def run_main(host: str, port: int): ) cprint(response, "green") + # register HF dataset + response = await client.create_dataset( + dataset_def=HuggingfaceDatasetDef( + identifier="hellaswag", + dataset_name="hellaswag", + kwargs={"split": "validation", "trust_remote_code": True}, + ) + ) + cprint(response, "green") + # get dataset get_dataset = await client.get_dataset( dataset_identifier="test-dataset", diff --git a/llama_stack/apis/evals/client.py b/llama_stack/apis/evals/client.py index b4d1c39fe..d61de8c39 100644 --- a/llama_stack/apis/evals/client.py +++ b/llama_stack/apis/evals/client.py @@ -12,6 +12,7 @@ import httpx from termcolor import cprint from .evals import * # noqa: F403 +from ..datasets.client import DatasetsClient class EvaluationClient(Evals): @@ -54,13 +55,31 @@ class EvaluationClient(Evals): async def run_main(host: str, port: int): client = EvaluationClient(f"http://{host}:{port}") + dataset_client = DatasetsClient(f"http://{host}:{port}") + # Custom Eval Task + + # 1. register custom dataset + response = await dataset_client.create_dataset( + dataset_def=CustomDatasetDef( + identifier="mmlu-simple-eval-en", + url="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", + ), + ) + cprint(f"datasets/create: {response}", "cyan") + + # 2. run evals on the registered dataset response = await client.run_evals( model="Llama3.1-8B-Instruct", dataset="mmlu-simple-eval-en", task="mmlu", ) + if response.formatted_report: + cprint(response.formatted_report, "green") + else: + cprint(f"Response: {response}", "green") + # Eleuther Eval Task # response = await client.run_evals( # model="Llama3.1-8B-Instruct", @@ -70,10 +89,6 @@ async def run_main(host: str, port: int): # n_samples=2, # ), # ) - if response.formatted_report: - cprint(response.formatted_report, "green") - else: - cprint(f"Response: {response}", "green") def main(host: str, port: int): diff --git a/llama_stack/distribution/registry/datasets/__init__.py b/llama_stack/distribution/registry/datasets/__init__.py index 384028b9e..816475812 100644 --- a/llama_stack/distribution/registry/datasets/__init__.py +++ b/llama_stack/distribution/registry/datasets/__init__.py @@ -3,32 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -# TODO: make these import config based from llama_stack.apis.datasets import * # noqa: F403 from ..registry import Registry -from .dataset_wrappers import CustomDataset, HuggingfaceDataset class DatasetRegistry(Registry[BaseDataset]): _REGISTRY: Dict[str, BaseDataset] = {} - - -DATASETS_REGISTRY = [ - CustomDataset( - config=CustomDatasetDef( - identifier="mmlu-simple-eval-en", - url="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", - ) - ), - HuggingfaceDataset( - config=HuggingfaceDatasetDef( - identifier="hellaswag", - dataset_name="hellaswag", - kwargs={"split": "validation", "trust_remote_code": True}, - ) - ), -] - -for d in DATASETS_REGISTRY: - DatasetRegistry.register(d.dataset_id, d)