This commit is contained in:
Xi Yan 2025-03-15 17:16:38 -07:00
parent b561cfd902
commit 659f5e86ee
5 changed files with 1094 additions and 1108 deletions

View file

@ -16,15 +16,17 @@ from ..datasetio.test_datasetio import register_dataset
@pytest.mark.parametrize("scoring_fn_id", ["basic::equality"])
def test_evaluate_rows(llama_stack_client, text_model_id, scoring_fn_id):
register_dataset(llama_stack_client, for_generation=True, dataset_id="test_dataset_for_eval")
register_dataset(
llama_stack_client, for_generation=True, dataset_id="test_dataset_for_eval"
)
response = llama_stack_client.datasets.list()
assert any(x.identifier == "test_dataset_for_eval" for x in response)
rows = llama_stack_client.datasetio.get_rows_paginated(
rows = llama_stack_client.datasets.iterrows(
dataset_id="test_dataset_for_eval",
rows_in_page=3,
limit=3,
)
assert len(rows.rows) == 3
assert len(rows.data) == 3
scoring_functions = [
scoring_fn_id,
@ -40,7 +42,7 @@ def test_evaluate_rows(llama_stack_client, text_model_id, scoring_fn_id):
response = llama_stack_client.eval.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=rows.rows,
input_rows=rows.data,
scoring_functions=scoring_functions,
benchmark_config={
"eval_candidate": {
@ -59,7 +61,9 @@ def test_evaluate_rows(llama_stack_client, text_model_id, scoring_fn_id):
@pytest.mark.parametrize("scoring_fn_id", ["basic::subset_of"])
def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
register_dataset(llama_stack_client, for_generation=True, dataset_id="test_dataset_for_eval_2")
register_dataset(
llama_stack_client, for_generation=True, dataset_id="test_dataset_for_eval_2"
)
benchmark_id = str(uuid.uuid4())
llama_stack_client.benchmarks.register(
benchmark_id=benchmark_id,
@ -80,10 +84,14 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
},
)
assert response.job_id == "0"
job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id)
job_status = llama_stack_client.eval.jobs.status(
job_id=response.job_id, benchmark_id=benchmark_id
)
assert job_status and job_status == "completed"
eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id)
eval_response = llama_stack_client.eval.jobs.retrieve(
job_id=response.job_id, benchmark_id=benchmark_id
)
assert eval_response is not None
assert len(eval_response.generations) == 5
assert scoring_fn_id in eval_response.scores