diff --git a/llama_stack/apis/evals/client.py b/llama_stack/apis/evals/client.py
index b79547713..7d812817b 100644
--- a/llama_stack/apis/evals/client.py
+++ b/llama_stack/apis/evals/client.py
@@ -12,9 +12,28 @@ import httpx
 from termcolor import cprint
 
 from .evals import *  # noqa: F403
+import base64
+import mimetypes
+import os
+
 from ..datasets.client import DatasetsClient
 
 
+def data_url_from_file(file_path: str) -> str:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return data_url
+
+
 class EvaluationClient(Evals):
     def __init__(self, base_url: str):
         self.base_url = base_url
@@ -70,9 +89,8 @@ class EvaluationClient(Evals):
         return EvaluateResponse(**response.json())
 
 
-async def run_main(host: str, port: int):
+async def run_main(host: str, port: int, eval_dataset_path: str = ""):
     client = EvaluationClient(f"http://{host}:{port}")
-
     dataset_client = DatasetsClient(f"http://{host}:{port}")
 
     # Full Eval Task
@@ -114,10 +132,19 @@ async def run_main(host: str, port: int):
     )
     cprint(response, "cyan")
 
+    response = await dataset_client.create_dataset(
+        dataset_def=CustomDatasetDef(
+            identifier="rag-evals",
+            url=data_url_from_file(eval_dataset_path),
+        )
+    )
+    cprint(response, "cyan")
+
     # 2. run evals on the registered dataset
     response = await client.run_scorer(
         dataset_config=EvaluateDatasetConfig(
-            dataset_identifier="Llama-3.1-8B-Instruct-evals__mmlu_pro__details",
+            dataset_identifier="rag-evals",
+            # dataset_identifier="Llama-3.1-8B-Instruct-evals__mmlu_pro__details",
             row_limit=10,
         ),
         eval_scoring_config=EvaluateScoringConfig(
@@ -141,8 +168,8 @@ async def run_main(host: str, port: int):
     #     )
 
 
-def main(host: str, port: int):
-    asyncio.run(run_main(host, port))
+def main(host: str, port: int, eval_dataset_path: str = ""):
+    asyncio.run(run_main(host, port, eval_dataset_path))
 
 
 if __name__ == "__main__":
diff --git a/llama_stack/distribution/registry/datasets/dataset_wrappers.py b/llama_stack/distribution/registry/datasets/dataset_wrappers.py
index 88a487d60..410ad394a 100644
--- a/llama_stack/distribution/registry/datasets/dataset_wrappers.py
+++ b/llama_stack/distribution/registry/datasets/dataset_wrappers.py
@@ -3,10 +3,15 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import base64
+import io
+from urllib.parse import unquote
+
 import pandas
 from datasets import Dataset, load_dataset
 
 from llama_stack.apis.datasets import *  # noqa: F403
+from llama_stack.providers.utils.memory.vector_store import parse_data_url
 
 
 class CustomDataset(BaseDataset[DictSample]):
@@ -37,11 +42,31 @@ class CustomDataset(BaseDataset[DictSample]):
         if self.dataset:
             return
 
-        # TODO: better support w/ data url
+        # TODO: more robust support w/ data url
        if self.config.url.endswith(".csv"):
             df = pandas.read_csv(self.config.url)
         elif self.config.url.endswith(".xlsx"):
             df = pandas.read_excel(self.config.url)
+        elif self.config.url.startswith("data:"):
+            parts = parse_data_url(self.config.url)
+            data = parts["data"]
+            if parts["is_base64"]:
+                data = base64.b64decode(data)
+            else:
+                data = unquote(data)
+                encoding = parts["encoding"] or "utf-8"
+                data = data.encode(encoding)
+
+            mime_type = parts["mimetype"]
+            mime_category = mime_type.split("/")[0]
+            data_bytes = io.BytesIO(data)
+
+            if mime_category == "text":
+                df = pandas.read_csv(data_bytes)
+            else:
+                df = pandas.read_excel(data_bytes)
+        else:
+            raise ValueError(f"Unsupported file type: {self.config.url}")
 
         if n_samples is not None:
             df = df.sample(n=n_samples)
diff --git a/llama_stack/providers/impls/meta_reference/evals/tasks/run_scoring_task.py b/llama_stack/providers/impls/meta_reference/evals/tasks/run_scoring_task.py
index 9e4821a73..9ff6cde4d 100644
--- a/llama_stack/providers/impls/meta_reference/evals/tasks/run_scoring_task.py
+++ b/llama_stack/providers/impls/meta_reference/evals/tasks/run_scoring_task.py
@@ -11,7 +11,6 @@ from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers impor
 
 from llama_stack.apis.evals import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-from termcolor import cprint
 
 
 class RunScoringTask(BaseTask):
@@ -62,9 +61,8 @@ class RunScoringTask(BaseTask):
         dataset.load(n_samples=dataset_config.row_limit)
         print(f"Running on {len(dataset)} samples")
 
-        # transform dataset into
+        # transform dataset into List[ScorerInputSample]
         postprocessed = self.transform_score_input_sample(dataset)
-        cprint(postprocessed, "blue")
 
         # F3 - scorer
         scorer_config_list = eval_scoring_config.scorer_config_list
diff --git a/llama_stack/providers/registry/evals.py b/llama_stack/providers/registry/evals.py
index 6ea4c16f5..a8a7e735f 100644
--- a/llama_stack/providers/registry/evals.py
+++ b/llama_stack/providers/registry/evals.py
@@ -22,6 +22,7 @@ def available_providers() -> List[ProviderSpec]:
                 "datasets",
                 "numpy",
                 "autoevals",
+                "openpyxl",
             ],
             module="llama_stack.providers.impls.meta_reference.evals",
             config_class="llama_stack.providers.impls.meta_reference.evals.MetaReferenceEvalsImplConfig",
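
For reviewers who want to sanity-check the round trip outside the stack, here is a minimal standalone sketch of what the patch wires together: encode a local CSV as a data URL (as the new data_url_from_file helper in client.py does) and decode it back into a pandas DataFrame (as the new data: branch in CustomDataset.load does). It uses only the standard library and pandas rather than llama_stack's parse_data_url helper, and the file name eval_dataset.csv is a placeholder.

import base64
import io
import mimetypes
from urllib.parse import unquote

import pandas


def data_url_from_file(file_path: str) -> str:
    # Same idea as the helper added in client.py: read the file, base64-encode it,
    # and wrap it in a data URL tagged with the guessed MIME type.
    with open(file_path, "rb") as f:
        content = f.read()
    mime_type, _ = mimetypes.guess_type(file_path)
    base64_content = base64.b64encode(content).decode("utf-8")
    return f"data:{mime_type};base64,{base64_content}"


def dataframe_from_data_url(data_url: str) -> pandas.DataFrame:
    # Stand-in for the new data: branch in CustomDataset.load, without the
    # llama_stack parse_data_url helper: split the header from the payload,
    # decode it, and hand the bytes to pandas.
    header, _, payload = data_url.partition(",")
    mime_type = header[len("data:"):].split(";")[0]
    if ";base64" in header:
        data = base64.b64decode(payload)
    else:
        data = unquote(payload).encode("utf-8")
    buf = io.BytesIO(data)
    return pandas.read_csv(buf) if mime_type.startswith("text") else pandas.read_excel(buf)


if __name__ == "__main__":
    url = data_url_from_file("eval_dataset.csv")  # hypothetical local CSV
    print(dataframe_from_data_url(url).head())

For .xlsx payloads the same path goes through pandas.read_excel, which is why the patch adds openpyxl to the provider's pip_packages.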