[Evals API][3/n] scoring_functions / scoring meta-reference implementations (#296)

* wip

* dataset validation

* test_scoring

* cleanup

* clean up test

* comments

* error checking

* dataset client

* test client:

* datasetio client

* clean up

* basic scoring function works

* scorer wip

* equality scorer

* score batch impl

* score batch

* update scoring test

* refactor

* validate scorer input

* address comments

* add all rows scores to ScoringResult

* bugfix

* scoring function def rename
This commit is contained in:
Xi Yan 2024-10-24 14:52:30 -07:00 committed by GitHub
parent e70420a06e
commit cb84034567
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 904 additions and 51 deletions

View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import os
from pathlib import Path
from typing import Optional
import fire
import httpx
from termcolor import cprint
from .datasets import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
class DatasetsClient(Datasets):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_dataset(
self,
dataset_def: DatasetDefWithProvider,
) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/datasets/register",
json={
"dataset_def": json.loads(dataset_def.json()),
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
return
async def get_dataset(
self,
dataset_identifier: str,
) -> Optional[DatasetDefWithProvider]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/datasets/get",
params={
"dataset_identifier": dataset_identifier,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return DatasetDefWithProvider(**response.json())
async def list_datasets(self) -> List[DatasetDefWithProvider]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/datasets/list",
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return [DatasetDefWithProvider(**x) for x in response.json()]
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")
# register dataset
test_file = (
Path(os.path.abspath(__file__)).parent.parent.parent
/ "providers/tests/datasetio/test_dataset.csv"
)
test_url = data_url_from_file(str(test_file))
response = await client.register_dataset(
DatasetDefWithProvider(
identifier="test-dataset",
provider_id="meta0",
url=URL(
uri=test_url,
),
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
)
)
# list datasets
list_dataset = await client.list_datasets()
cprint(list_dataset, "blue")
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)