[Evals API][3/n] scoring_functions / scoring meta-reference implementations (#296)

* wip

* dataset validation

* test_scoring

* cleanup

* clean up test

* comments

* error checking

* dataset client

* test client:

* datasetio client

* clean up

* basic scoring function works

* scorer wip

* equality scorer

* score batch impl

* score batch

* update scoring test

* refactor

* validate scorer input

* address comments

* add all rows scores to ScoringResult

* bugfix

* scoring function def rename
This commit is contained in:
Xi Yan 2024-10-24 14:52:30 -07:00 committed by GitHub
parent e70420a06e
commit cb84034567
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 904 additions and 51 deletions

View file

@ -8,8 +8,13 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import base64
import mimetypes
from pathlib import Path
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
@ -41,14 +46,35 @@ async def datasetio_settings():
}
def data_url_from_file(file_path: str) -> str:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url
async def register_dataset(datasets_impl: Datasets):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
dataset = DatasetDefWithProvider(
identifier="test_dataset",
provider_id=os.environ["PROVIDER_ID"],
url=URL(
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
uri=test_url,
),
columns_schema={},
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
)
await datasets_impl.register_dataset(dataset)
@ -100,10 +126,10 @@ async def test_get_rows_paginated(datasetio_settings):
# iterate over all rows
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=10,
rows_in_page=2,
page_token=response.next_page_token,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 10
assert response.next_page_token == "13"
assert len(response.rows) == 2
assert response.next_page_token == "5"