forked from phoenix-oss/llama-stack-mirror
chore: rename task_config to benchmark_config (#1397)
# What does this PR do?

- This was missed from the previous deprecation: https://github.com/meta-llama/llama-stack/pull/1186
- Part of https://github.com/meta-llama/llama-stack/issues/1396

## Test Plan

```
pytest -v -s --nbval-lax ./llama-stack/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb
```
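For context, this is what an eval call looks like with the renamed parameter. A minimal sketch assuming a running Llama Stack server and the `llama_stack_client` Python package; the server URL, benchmark id, model name, and sampling params below are illustrative placeholders, not taken from this PR:

```python
from llama_stack_client import LlamaStackClient

# Assumed local server URL; adjust to your deployment.
client = LlamaStackClient(base_url="http://localhost:8321")

# After this change, the eval APIs accept `benchmark_config` instead of `task_config`.
job = client.eval.run_eval(
    benchmark_id="meta-reference::mmlu",  # hypothetical benchmark id
    benchmark_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.2-3B-Instruct",  # hypothetical model id
            "sampling_params": {"max_tokens": 512},
        },
    },
)
print(job.job_id)
```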
Parent: 158b6dc404 · Commit: e9a37bad63
12 changed files with 55 additions and 46 deletions
docs/_static/llama-stack-spec.html (vendored, 8 changes)

```diff
@@ -6355,7 +6355,7 @@
             "type": "string"
           }
         },
-        "task_config": {
+        "benchmark_config": {
           "$ref": "#/components/schemas/BenchmarkConfig"
         }
       },
@@ -6363,7 +6363,7 @@
       "required": [
         "input_rows",
         "scoring_functions",
-        "task_config"
+        "benchmark_config"
       ],
       "title": "EvaluateRowsRequest"
     },
@@ -9248,13 +9248,13 @@
     "RunEvalRequest": {
       "type": "object",
       "properties": {
-        "task_config": {
+        "benchmark_config": {
           "$ref": "#/components/schemas/BenchmarkConfig"
         }
       },
       "additionalProperties": false,
       "required": [
-        "task_config"
+        "benchmark_config"
       ],
       "title": "RunEvalRequest"
     },
```
docs/_static/llama-stack-spec.yaml (vendored, 8 changes)

```diff
@@ -4357,13 +4357,13 @@ components:
         type: array
         items:
           type: string
-      task_config:
+      benchmark_config:
         $ref: '#/components/schemas/BenchmarkConfig'
     additionalProperties: false
     required:
       - input_rows
       - scoring_functions
-      - task_config
+      - benchmark_config
     title: EvaluateRowsRequest
   EvaluateResponse:
     type: object
@@ -6168,11 +6168,11 @@ components:
   RunEvalRequest:
     type: object
     properties:
-      task_config:
+      benchmark_config:
        $ref: '#/components/schemas/BenchmarkConfig'
     additionalProperties: false
     required:
-      - task_config
+      - benchmark_config
     title: RunEvalRequest
   Job:
     type: object
```
(file path not shown)

```diff
@@ -3675,7 +3675,7 @@
     "    benchmark_id=\"llama3.2-3B-instruct:tax_eval\",\n",
     "    input_rows=eval_rows.rows,\n",
     "    scoring_functions=[\"braintrust::answer-similarity\"],\n",
-    "    task_config={\n",
+    "    benchmark_config={\n",
     "        \"type\": \"benchmark\",\n",
     "        \"eval_candidate\": {\n",
     "            \"type\": \"model\",\n",
@@ -6383,7 +6383,7 @@
     "    benchmark_id=\"Llama-3.2-3B-Instruct-sft-0:tax_eval\",\n",
     "    input_rows=eval_rows.rows,\n",
     "    scoring_functions=[\"braintrust::answer-similarity\"],\n",
-    "    task_config={\n",
+    "    benchmark_config={\n",
     "        \"type\": \"benchmark\",\n",
     "        \"eval_candidate\": {\n",
     "            \"type\": \"model\",\n",
```
(file path not shown)

```diff
@@ -781,7 +781,7 @@
     "    benchmark_id=\"meta-reference::mmmu\",\n",
     "    input_rows=eval_rows,\n",
     "    scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
-    "    task_config={\n",
+    "    benchmark_config={\n",
     "        \"type\": \"benchmark\",\n",
     "        \"eval_candidate\": {\n",
     "            \"type\": \"model\",\n",
@@ -960,7 +960,7 @@
     "    benchmark_id=\"meta-reference::simpleqa\",\n",
     "    input_rows=eval_rows.rows,\n",
     "    scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
-    "    task_config={\n",
+    "    benchmark_config={\n",
     "        \"type\": \"benchmark\",\n",
     "        \"eval_candidate\": {\n",
     "            \"type\": \"model\",\n",
@@ -1109,7 +1109,7 @@
     "    benchmark_id=\"meta-reference::simpleqa\",\n",
     "    input_rows=eval_rows.rows,\n",
     "    scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
-    "    task_config={\n",
+    "    benchmark_config={\n",
     "        \"type\": \"benchmark\",\n",
     "        \"eval_candidate\": {\n",
     "            \"type\": \"agent\",\n",
```
(file path not shown)

```diff
@@ -51,7 +51,7 @@ response = client.eval.evaluate_rows(
     benchmark_id="meta-reference::mmmu",
     input_rows=eval_rows,
     scoring_functions=["basic::regex_parser_multiple_choice_answer"],
-    task_config={
+    benchmark_config={
         "type": "benchmark",
         "eval_candidate": {
             "type": "model",
@@ -109,7 +109,7 @@ response = client.eval.evaluate_rows(
     benchmark_id="meta-reference::simpleqa",
     input_rows=eval_rows.rows,
     scoring_functions=["llm-as-judge::405b-simpleqa"],
-    task_config={
+    benchmark_config={
         "type": "benchmark",
         "eval_candidate": {
             "type": "model",
@@ -158,7 +158,7 @@ response = client.eval.evaluate_rows(
     benchmark_id="meta-reference::simpleqa",
     input_rows=eval_rows.rows,
     scoring_functions=["llm-as-judge::405b-simpleqa"],
-    task_config={
+    benchmark_config={
         "type": "benchmark",
         "eval_candidate": {
             "type": "agent",
```
(file path not shown)

```diff
@@ -19,7 +19,7 @@ response = client.benchmarks.register(
 # Run evaluation
 job = client.eval.run_eval(
     benchmark_id="my_eval",
-    task_config={
+    benchmark_config={
         "type": "app",
         "eval_candidate": {"type": "agent", "config": agent_config},
     },
@@ -87,7 +87,7 @@ response = client.eval.evaluate_rows(
     benchmark_id="meta-reference::mmmu",
     input_rows=eval_rows,
     scoring_functions=["basic::regex_parser_multiple_choice_answer"],
-    task_config={
+    benchmark_config={
         "type": "benchmark",
         "eval_candidate": {
             "type": "model",
@@ -145,7 +145,7 @@ response = client.eval.evaluate_rows(
     benchmark_id="meta-reference::simpleqa",
     input_rows=eval_rows.rows,
     scoring_functions=["llm-as-judge::405b-simpleqa"],
-    task_config={
+    benchmark_config={
         "type": "benchmark",
         "eval_candidate": {
             "type": "model",
@@ -195,7 +195,7 @@ response = client.eval.evaluate_rows(
     benchmark_id="meta-reference::simpleqa",
     input_rows=eval_rows.rows,
     scoring_functions=["llm-as-judge::405b-simpleqa"],
-    task_config={
+    benchmark_config={
         "type": "benchmark",
         "eval_candidate": {
             "type": "agent",
```
(file path not shown)

```diff
@@ -63,7 +63,7 @@ class Eval(Protocol):
     async def run_eval(
         self,
         benchmark_id: str,
-        task_config: BenchmarkConfig,
+        benchmark_config: BenchmarkConfig,
     ) -> Job: ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
@@ -72,7 +72,7 @@ class Eval(Protocol):
         benchmark_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        task_config: BenchmarkConfig,
+        benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse: ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
```
(file path not shown)

```diff
@@ -81,7 +81,10 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logcat.debug("core", f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
+        logcat.debug(
+            "core",
+            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
+        )
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -328,7 +331,10 @@ class DatasetIORouter(DatasetIO):
         page_token: Optional[str] = None,
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult:
-        logcat.debug("core", f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}")
+        logcat.debug(
+            "core",
+            f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
+        )
         return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=rows_in_page,
@@ -387,7 +393,10 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logcat.debug("core", f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+        logcat.debug(
+            "core",
+            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
+        )
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
@@ -419,12 +428,12 @@ class EvalRouter(Eval):
     async def run_eval(
         self,
         benchmark_id: str,
-        task_config: BenchmarkConfig,
+        benchmark_config: BenchmarkConfig,
     ) -> Job:
         logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
             benchmark_id=benchmark_id,
-            task_config=task_config,
+            benchmark_config=benchmark_config,
         )

     async def evaluate_rows(
@@ -432,14 +441,14 @@ class EvalRouter(Eval):
         benchmark_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        task_config: BenchmarkConfig,
+        benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
             scoring_functions=scoring_functions,
-            task_config=task_config,
+            benchmark_config=benchmark_config,
         )

     async def job_status(
```
(file path not shown)

```diff
@@ -212,7 +212,7 @@ def run_evaluation_3():
             benchmark_id=selected_benchmark,
             input_rows=[r],
             scoring_functions=benchmarks[selected_benchmark].scoring_functions,
-            task_config=benchmark_config,
+            benchmark_config=benchmark_config,
         )

         for k in r.keys():
```
|
||||||
async def run_eval(
|
async def run_eval(
|
||||||
self,
|
self,
|
||||||
benchmark_id: str,
|
benchmark_id: str,
|
||||||
task_config: BenchmarkConfig,
|
benchmark_config: BenchmarkConfig,
|
||||||
) -> Job:
|
) -> Job:
|
||||||
task_def = self.benchmarks[benchmark_id]
|
task_def = self.benchmarks[benchmark_id]
|
||||||
dataset_id = task_def.dataset_id
|
dataset_id = task_def.dataset_id
|
||||||
|
@ -92,13 +92,13 @@ class MetaReferenceEvalImpl(
|
||||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
||||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples),
|
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||||
)
|
)
|
||||||
res = await self.evaluate_rows(
|
res = await self.evaluate_rows(
|
||||||
benchmark_id=benchmark_id,
|
benchmark_id=benchmark_id,
|
||||||
input_rows=all_rows.rows,
|
input_rows=all_rows.rows,
|
||||||
scoring_functions=scoring_functions,
|
scoring_functions=scoring_functions,
|
||||||
task_config=task_config,
|
benchmark_config=benchmark_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: currently needs to wait for generation before returning
|
# TODO: currently needs to wait for generation before returning
|
||||||
|
@ -108,9 +108,9 @@ class MetaReferenceEvalImpl(
|
||||||
return Job(job_id=job_id)
|
return Job(job_id=job_id)
|
||||||
|
|
||||||
async def _run_agent_generation(
|
async def _run_agent_generation(
|
||||||
self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
|
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
candidate = task_config.eval_candidate
|
candidate = benchmark_config.eval_candidate
|
||||||
create_response = await self.agents_api.create_agent(candidate.config)
|
create_response = await self.agents_api.create_agent(candidate.config)
|
||||||
agent_id = create_response.agent_id
|
agent_id = create_response.agent_id
|
||||||
|
|
||||||
|
@ -151,9 +151,9 @@ class MetaReferenceEvalImpl(
|
||||||
return generations
|
return generations
|
||||||
|
|
||||||
async def _run_model_generation(
|
async def _run_model_generation(
|
||||||
self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
|
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
candidate = task_config.eval_candidate
|
candidate = benchmark_config.eval_candidate
|
||||||
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
|
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
|
||||||
|
|
||||||
generations = []
|
generations = []
|
||||||
|
@ -189,13 +189,13 @@ class MetaReferenceEvalImpl(
|
||||||
benchmark_id: str,
|
benchmark_id: str,
|
||||||
input_rows: List[Dict[str, Any]],
|
input_rows: List[Dict[str, Any]],
|
||||||
scoring_functions: List[str],
|
scoring_functions: List[str],
|
||||||
task_config: BenchmarkConfig,
|
benchmark_config: BenchmarkConfig,
|
||||||
) -> EvaluateResponse:
|
) -> EvaluateResponse:
|
||||||
candidate = task_config.eval_candidate
|
candidate = benchmark_config.eval_candidate
|
||||||
if candidate.type == "agent":
|
if candidate.type == "agent":
|
||||||
generations = await self._run_agent_generation(input_rows, task_config)
|
generations = await self._run_agent_generation(input_rows, benchmark_config)
|
||||||
elif candidate.type == "model":
|
elif candidate.type == "model":
|
||||||
generations = await self._run_model_generation(input_rows, task_config)
|
generations = await self._run_model_generation(input_rows, benchmark_config)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid candidate type: {candidate.type}")
|
raise ValueError(f"Invalid candidate type: {candidate.type}")
|
||||||
|
|
||||||
|
@ -204,9 +204,9 @@ class MetaReferenceEvalImpl(
|
||||||
input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
|
input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
if task_config.scoring_params is not None:
|
if benchmark_config.scoring_params is not None:
|
||||||
scoring_functions_dict = {
|
scoring_functions_dict = {
|
||||||
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
|
scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
|
||||||
for scoring_fn_id in scoring_functions
|
for scoring_fn_id in scoring_functions
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
|
|
(file path not shown)

```diff
@@ -68,7 +68,7 @@ class Testeval:
             benchmark_id=benchmark_id,
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            task_config=AppBenchmarkConfig(
+            benchmark_config=AppBenchmarkConfig(
                 eval_candidate=ModelCandidate(
                     model=inference_model,
                     sampling_params=SamplingParams(),
@@ -111,7 +111,7 @@ class Testeval:
         )
         response = await eval_impl.run_eval(
             benchmark_id=benchmark_id,
-            task_config=AppBenchmarkConfig(
+            benchmark_config=AppBenchmarkConfig(
                 eval_candidate=ModelCandidate(
                     model=inference_model,
                     sampling_params=SamplingParams(),
@@ -169,7 +169,7 @@ class Testeval:
         benchmark_id = "meta-reference-mmlu"
         response = await eval_impl.run_eval(
             benchmark_id=benchmark_id,
-            task_config=BenchmarkBenchmarkConfig(
+            benchmark_config=BenchmarkBenchmarkConfig(
                 eval_candidate=ModelCandidate(
                     model=inference_model,
                     sampling_params=SamplingParams(),
```