api refactor

Xi Yan 2024-11-07 13:54:26 -08:00
parent 97dcd5704c
commit 51c20f9c29
8 changed files with 64 additions and 59 deletions


@@ -76,13 +76,13 @@ class Testeval:
         ]
         response = await eval_impl.run_eval(
-            eval_task_def=EvalTaskDef(
+            task=EvalTaskDef(
                 # NOTE: this is needed to make the router work for all app evals
                 identifier="meta-reference::app_eval",
                 dataset_id="test_dataset_for_eval",
                 scoring_functions=scoring_functions,
             ),
-            eval_task_config=AppEvalTaskConfig(
+            task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),

@@ -44,10 +44,10 @@ class TestScoring:
         )
         assert len(rows.rows) == 3
-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-            "meta-reference::equality",
-        ]
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": None,
+            "meta-reference::equality": None,
+        }
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
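
After this change, scoring_functions is a dict from scoring-function identifier to an optional params object rather than a plain list of identifiers. A minimal sketch, assuming a rows result with a .rows attribute as in the test above (the helper name is hypothetical):

    # Hypothetical helper: mapping each function id to None requests that
    # function's default parameters.
    async def score_with_defaults(scoring_impl, rows):
        scoring_functions = {
            "meta-reference::llm_as_judge_8b_correctness": None,
            "meta-reference::equality": None,
        }
        return await scoring_impl.score(
            input_rows=rows.rows,
            scoring_functions=scoring_functions,
        )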
@@ -83,7 +83,7 @@ class TestScoring:
         )
         assert len(rows.rows) == 3
-        params = {
+        scoring_functions = {
             "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
@@ -91,13 +91,9 @@ class TestScoring:
             )
         }
-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-        ]
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions:
@@ -108,7 +104,6 @@ class TestScoring:
         response = await scoring_impl.score_batch(
             dataset_id="test_dataset",
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions:
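
Per-function parameters now ride along in the same dict instead of a separate scoring_params= argument, and score_batch follows the same convention. A minimal sketch; the import path for LLMAsJudgeScoringFnParams is assumed and the helper name is hypothetical:

    # Hypothetical helper: params are supplied inline as the dict value,
    # replacing the removed scoring_params= keyword. Import path assumed.
    from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams

    async def score_batch_with_judge(scoring_impl):
        scoring_functions = {
            "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                judge_model="Llama3.1-405B-Instruct",
                prompt_template=(
                    "Output a number response in the following format: "
                    "Score: <number>, where <number> is the number between 0 and 9."
                ),
            ),
        }
        response = await scoring_impl.score_batch(
            dataset_id="test_dataset",
            scoring_functions=scoring_functions,
        )
        # One result per requested scoring function.
        assert len(response.results) == len(scoring_functions)
        return response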