datasets: accept file uploads for eval dataset registration

This commit is contained in:
Xi Yan 2024-10-14 23:36:15 -07:00
parent 3c29108b6e
commit ec6c63ba57
4 changed files with 58 additions and 9 deletions

View file

@ -12,9 +12,28 @@ import httpx
from termcolor import cprint
from .evals import * # noqa: F403
import base64
import mimetypes
import os
from ..datasets.client import DatasetsClient
def data_url_from_file(file_path: str) -> str:
    """Read a local file and embed its contents as a base64 ``data:`` URL.

    Args:
        file_path: Path to the file to encode.

    Returns:
        A ``data:<mime_type>;base64,<payload>`` URL string suitable for
        registering the file contents as a dataset URL.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)
    # guess_type returns None for unrecognized extensions; without this
    # fallback the f-string would emit a literal "data:None;base64,..."
    # URL, which downstream parsers reject.
    if mime_type is None:
        mime_type = "application/octet-stream"

    return f"data:{mime_type};base64,{base64_content}"
class EvaluationClient(Evals):
    """Client-side implementation of the Evals API that issues requests
    over HTTP to a llama-stack server at ``base_url``."""

    def __init__(self, base_url: str):
        # Root endpoint of the server, e.g. "http://localhost:5000";
        # request paths are appended to this by the API methods.
        self.base_url = base_url
@ -70,9 +89,8 @@ class EvaluationClient(Evals):
return EvaluateResponse(**response.json())
async def run_main(host: str, port: int):
async def run_main(host: str, port: int, eval_dataset_path: str = ""):
client = EvaluationClient(f"http://{host}:{port}")
dataset_client = DatasetsClient(f"http://{host}:{port}")
# Full Eval Task
@ -114,10 +132,19 @@ async def run_main(host: str, port: int):
)
cprint(response, "cyan")
response = await dataset_client.create_dataset(
dataset_def=CustomDatasetDef(
identifier="rag-evals",
url=data_url_from_file(eval_dataset_path),
)
)
cprint(response, "cyan")
# 2. run evals on the registered dataset
response = await client.run_scorer(
dataset_config=EvaluateDatasetConfig(
dataset_identifier="Llama-3.1-8B-Instruct-evals__mmlu_pro__details",
dataset_identifier="rag-evals",
# dataset_identifier="Llama-3.1-8B-Instruct-evals__mmlu_pro__details",
row_limit=10,
),
eval_scoring_config=EvaluateScoringConfig(
@ -141,8 +168,8 @@ async def run_main(host: str, port: int):
# )
def main(host: str, port: int):
asyncio.run(run_main(host, port))
def main(host: str, port: int, eval_dataset_path: str = ""):
    """CLI entry point: drive the async eval flow against ``host:port``.

    ``eval_dataset_path`` is forwarded to ``run_main``, which uploads the
    file as a data-URL dataset.
    NOTE(review): the default "" would make ``data_url_from_file`` raise
    FileNotFoundError inside ``run_main`` — presumably callers always pass
    a real path; confirm before relying on the default.
    """
    asyncio.run(run_main(host, port, eval_dataset_path))
if __name__ == "__main__":

View file

@ -3,10 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
import pandas
from datasets import Dataset, load_dataset
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.providers.utils.memory.vector_store import parse_data_url
class CustomDataset(BaseDataset[DictSample]):
@ -37,11 +40,31 @@ class CustomDataset(BaseDataset[DictSample]):
if self.dataset:
return
# TODO: better support w/ data url
# TODO: more robust support w/ data url
if self.config.url.endswith(".csv"):
df = pandas.read_csv(self.config.url)
elif self.config.url.endswith(".xlsx"):
df = pandas.read_excel(self.config.url)
elif self.config.url.startswith("data:"):
parts = parse_data_url(self.config.url)
data = parts["data"]
if parts["is_base64"]:
data = base64.b64decode(data)
else:
data = unquote(data)
encoding = parts["encoding"] or "utf-8"
data = data.encode(encoding)
mime_type = parts["mimetype"]
mime_category = mime_type.split("/")[0]
data_bytes = io.BytesIO(data)
if mime_category == "text":
df = pandas.read_csv(data_bytes)
else:
df = pandas.read_excel(data_bytes)
else:
raise ValueError(f"Unsupported file type: {self.config.url}")
if n_samples is not None:
df = df.sample(n=n_samples)

View file

@ -11,7 +11,6 @@ from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers impor
from llama_stack.apis.evals import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from termcolor import cprint
class RunScoringTask(BaseTask):
@ -62,9 +61,8 @@ class RunScoringTask(BaseTask):
dataset.load(n_samples=dataset_config.row_limit)
print(f"Running on {len(dataset)} samples")
# transform dataset into
# transform dataset into List[ScorerInputSample]
postprocessed = self.transform_score_input_sample(dataset)
cprint(postprocessed, "blue")
# F3 - scorer
scorer_config_list = eval_scoring_config.scorer_config_list

View file

@ -22,6 +22,7 @@ def available_providers() -> List[ProviderSpec]:
"datasets",
"numpy",
"autoevals",
"openpyxl",
],
module="llama_stack.providers.impls.meta_reference.evals",
config_class="llama_stack.providers.impls.meta_reference.evals.MetaReferenceEvalsImplConfig",