dataset accept file uploads

Xi Yan 2024-10-14 23:36:15 -07:00
parent 3c29108b6e
commit ec6c63ba57
4 changed files with 58 additions and 9 deletions


@@ -12,9 +12,28 @@ import httpx
 from termcolor import cprint
 
 from .evals import * # noqa: F403
+import base64
+import mimetypes
+import os
+
 from ..datasets.client import DatasetsClient
 
 
+def data_url_from_file(file_path: str) -> str:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return data_url
+
+
 class EvaluationClient(Evals):
     def __init__(self, base_url: str):
         self.base_url = base_url
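
The new data_url_from_file helper inlines an entire local file into a base64 data URL, which is what lets the dataset registration below ship file contents inside an ordinary JSON payload. A quick smoke test of the helper's output shape (a hedged sketch; the "eval.csv" file and its contents are invented for illustration):

# Hypothetical check of data_url_from_file's output shape.
# Note: mimetypes.guess_type returns None for unrecognized extensions,
# which would yield a malformed "data:None;base64,..." URL.
with open("eval.csv", "w") as f:
    f.write("input_query,expected_answer\nWhat is 2+2?,4\n")

url = data_url_from_file("eval.csv")
print(url.split(",")[0])  # -> data:text/csv;base64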
@@ -70,9 +89,8 @@ class EvaluationClient(Evals):
         return EvaluateResponse(**response.json())
 
 
-async def run_main(host: str, port: int):
+async def run_main(host: str, port: int, eval_dataset_path: str = ""):
    client = EvaluationClient(f"http://{host}:{port}")
    dataset_client = DatasetsClient(f"http://{host}:{port}")
 
    # Full Eval Task
@@ -114,10 +132,19 @@ async def run_main(host: str, port: int):
     )
     cprint(response, "cyan")
 
+    response = await dataset_client.create_dataset(
+        dataset_def=CustomDatasetDef(
+            identifier="rag-evals",
+            url=data_url_from_file(eval_dataset_path),
+        )
+    )
+    cprint(response, "cyan")
+
     # 2. run evals on the registered dataset
     response = await client.run_scorer(
         dataset_config=EvaluateDatasetConfig(
-            dataset_identifier="Llama-3.1-8B-Instruct-evals__mmlu_pro__details",
+            dataset_identifier="rag-evals",
+            # dataset_identifier="Llama-3.1-8B-Instruct-evals__mmlu_pro__details",
             row_limit=10,
         ),
         eval_scoring_config=EvaluateScoringConfig(
@@ -141,8 +168,8 @@ async def run_main(host: str, port: int):
     # )
 
 
-def main(host: str, port: int):
-    asyncio.run(run_main(host, port))
+def main(host: str, port: int, eval_dataset_path: str = ""):
+    asyncio.run(run_main(host, port, eval_dataset_path))
 
 
 if __name__ == "__main__":
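
Taken together, the client changes thread an optional eval_dataset_path from the CLI down to run_main, which registers the local file as the "rag-evals" dataset and then scores it. A minimal way to exercise the new flow, assuming a llama-stack server is already running (host, port, and file name here are placeholder values):

import asyncio

# Hedged sketch: upload a local CSV as "rag-evals", then run the scorer on it.
asyncio.run(run_main(host="localhost", port=5000, eval_dataset_path="eval.csv"))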


@@ -3,10 +3,13 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import base64
+import io
+from urllib.parse import unquote
+
 import pandas
 from datasets import Dataset, load_dataset
+
 from llama_stack.apis.datasets import * # noqa: F403
+from llama_stack.providers.utils.memory.vector_store import parse_data_url
 
 
 class CustomDataset(BaseDataset[DictSample]):
@@ -37,11 +40,31 @@ class CustomDataset(BaseDataset[DictSample]):
         if self.dataset:
             return
 
-        # TODO: better support w/ data url
+        # TODO: more robust support w/ data url
         if self.config.url.endswith(".csv"):
             df = pandas.read_csv(self.config.url)
         elif self.config.url.endswith(".xlsx"):
             df = pandas.read_excel(self.config.url)
+        elif self.config.url.startswith("data:"):
+            parts = parse_data_url(self.config.url)
+            data = parts["data"]
+
+            if parts["is_base64"]:
+                data = base64.b64decode(data)
+            else:
+                data = unquote(data)
+                encoding = parts["encoding"] or "utf-8"
+                data = data.encode(encoding)
+
+            mime_type = parts["mimetype"]
+            mime_category = mime_type.split("/")[0]
+            data_bytes = io.BytesIO(data)
+
+            if mime_category == "text":
+                df = pandas.read_csv(data_bytes)
+            else:
+                df = pandas.read_excel(data_bytes)
+        else:
+            raise ValueError(f"Unsupported file type: {self.config.url}")
 
         if n_samples is not None:
             df = df.sample(n=n_samples)
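
The new elif branch is the inverse of data_url_from_file: it recovers raw bytes from the data URL, then picks a pandas reader by MIME category (text/* is assumed to be CSV, anything else Excel). The diff delegates parsing to llama_stack's parse_data_url; a self-contained stdlib equivalent of the same decode logic might look roughly like this (a sketch assuming RFC 2397 URLs of the form data:<mimetype>[;charset=...][;base64],<payload>):

import base64
import io
from urllib.parse import unquote

import pandas


def dataframe_from_data_url(url: str) -> pandas.DataFrame:
    header, payload = url.split(",", 1)        # split metadata from payload
    fields = header[len("data:"):].split(";")
    mime_type = fields[0] or "text/plain"      # RFC 2397 default
    if "base64" in fields[1:]:
        data = base64.b64decode(payload)
    else:
        # Percent-encoded text; the provider honors the declared charset,
        # this sketch just assumes utf-8.
        data = unquote(payload).encode("utf-8")
    buf = io.BytesIO(data)
    if mime_type.split("/")[0] == "text":
        return pandas.read_csv(buf)
    return pandas.read_excel(buf)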


@@ -11,7 +11,6 @@ from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import
 from llama_stack.apis.evals import * # noqa: F403
 from llama_stack.apis.inference import * # noqa: F403
-from termcolor import cprint
 
 
 class RunScoringTask(BaseTask):
@@ -62,9 +61,8 @@ class RunScoringTask(BaseTask):
         dataset.load(n_samples=dataset_config.row_limit)
         print(f"Running on {len(dataset)} samples")
 
-        # transform dataset into
+        # transform dataset into List[ScorerInputSample]
         postprocessed = self.transform_score_input_sample(dataset)
-        cprint(postprocessed, "blue")
 
         # F3 - scorer
         scorer_config_list = eval_scoring_config.scorer_config_list


@@ -22,6 +22,7 @@ def available_providers() -> List[ProviderSpec]:
             "datasets",
             "numpy",
             "autoevals",
+            "openpyxl",
         ],
         module="llama_stack.providers.impls.meta_reference.evals",
         config_class="llama_stack.providers.impls.meta_reference.evals.MetaReferenceEvalsImplConfig",
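
The openpyxl dependency closes the loop: pandas delegates .xlsx parsing to openpyxl, so the new pandas.read_excel path in the dataset loader would fail with an ImportError without it. A quick round-trip check (hedged sketch with synthetic data):

import io

import pandas

# Write a tiny .xlsx to memory and read it back; both steps need openpyxl.
df = pandas.DataFrame({"input_query": ["What is 2+2?"], "expected_answer": ["4"]})
buf = io.BytesIO()
df.to_excel(buf, index=False)
buf.seek(0)
print(pandas.read_excel(buf))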