pre-commit fixes

Chantal D Gama Rose 2025-03-14 13:56:05 -07:00
parent 967dd0aa08
commit 7e211f8553
314 changed files with 5574 additions and 11369 deletions

@@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Dict
+from typing import Any, Dict
-from llama_stack.distribution.datatypes import Api, ProviderSpec
+from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceEvalConfig
async def get_provider_impl(
config: MetaReferenceEvalConfig,
-deps: Dict[Api, ProviderSpec],
+deps: Dict[Api, Any],
):
from .eval import MetaReferenceEvalImpl

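For context, a rough sketch of how a distribution loader might build the deps mapping and invoke this entry point after the typing change. The package path, the particular Api keys, the kvstore value, and the placeholder impl objects are illustrative assumptions, not taken from this diff:

from typing import Any, Dict

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

# Assumed package path for the provider shown above.
from llama_stack.providers.inline.eval.meta_reference import get_provider_impl
from llama_stack.providers.inline.eval.meta_reference.config import MetaReferenceEvalConfig


async def build_eval_provider(datasetio_impl: Any, scoring_impl: Any, inference_impl: Any, agents_impl: Any):
    # deps now maps Api members to live provider implementations (Any),
    # not ProviderSpec entries, matching the relaxed signature above.
    deps: Dict[Api, Any] = {
        Api.datasetio: datasetio_impl,
        Api.scoring: scoring_impl,
        Api.inference: inference_impl,
        Api.agents: agents_impl,
    }
    config = MetaReferenceEvalConfig(
        kvstore=SqliteKVStoreConfig(db_path="/tmp/meta_reference_eval.db")  # placeholder path
    )
    return await get_provider_impl(config, deps)
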
@@ -3,9 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
@@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
class MetaReferenceEvalConfig(BaseModel):
-kvstore: KVStoreConfig = SqliteKVStoreConfig(
-db_path=(RUNTIME_BASE_DIR / "meta_reference_eval.db").as_posix()
-)  # Uses SQLite config specific to Meta Reference Eval storage
+kvstore: KVStoreConfig
+@classmethod
+def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+return {
+"kvstore": SqliteKVStoreConfig.sample_run_config(
+__distro_dir__=__distro_dir__,
+db_name="meta_reference_eval.db",
+)
+}

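The new sample_run_config hook lets a distro template emit a default kvstore entry for this provider. A minimal usage sketch, assuming the same package path as above and a hypothetical distro directory; the exact shape of the returned kvstore dict is whatever SqliteKVStoreConfig.sample_run_config produces:

from llama_stack.providers.inline.eval.meta_reference.config import MetaReferenceEvalConfig

# Hypothetical distribution directory; the helper derives a
# meta_reference_eval.db path under it for the SQLite kvstore.
sample = MetaReferenceEvalConfig.sample_run_config(__distro_dir__="~/.llama/distributions/my-distro")
print(sample["kvstore"])
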
@@ -83,7 +83,7 @@ class MetaReferenceEvalImpl(
async def run_eval(
self,
benchmark_id: str,
-task_config: BenchmarkConfig,
+benchmark_config: BenchmarkConfig,
) -> Job:
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
@@ -92,13 +92,13 @@ class MetaReferenceEvalImpl(
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
-rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples),
+rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
-task_config=task_config,
+benchmark_config=benchmark_config,
)
# TODO: currently needs to wait for generation before returning
@@ -108,9 +108,9 @@ class MetaReferenceEvalImpl(
return Job(job_id=job_id)
async def _run_agent_generation(
-self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
+self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
) -> List[Dict[str, Any]]:
-candidate = task_config.eval_candidate
+candidate = benchmark_config.eval_candidate
create_response = await self.agents_api.create_agent(candidate.config)
agent_id = create_response.agent_id
@@ -151,9 +151,9 @@ class MetaReferenceEvalImpl(
return generations
async def _run_model_generation(
-self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
+self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
) -> List[Dict[str, Any]]:
-candidate = task_config.eval_candidate
+candidate = benchmark_config.eval_candidate
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
generations = []
@@ -189,13 +189,13 @@ class MetaReferenceEvalImpl(
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
-task_config: BenchmarkConfig,
+benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
-candidate = task_config.eval_candidate
+candidate = benchmark_config.eval_candidate
if candidate.type == "agent":
-generations = await self._run_agent_generation(input_rows, task_config)
+generations = await self._run_agent_generation(input_rows, benchmark_config)
elif candidate.type == "model":
-generations = await self._run_model_generation(input_rows, task_config)
+generations = await self._run_model_generation(input_rows, benchmark_config)
else:
raise ValueError(f"Invalid candidate type: {candidate.type}")
@@ -204,9 +204,9 @@ class MetaReferenceEvalImpl(
input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
]
-if task_config.scoring_params is not None:
+if benchmark_config.scoring_params is not None:
scoring_functions_dict = {
-scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
+scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions
}
else:
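
For reference, an illustrative caller updated for the task_config -> benchmark_config rename. Only attributes this diff actually touches (eval_candidate, num_examples, sampling_params.max_tokens, job_id) are relied on; the import paths, ModelCandidate fields, and identifiers below are assumptions:

from llama_stack.apis.eval import BenchmarkConfig, ModelCandidate  # assumed import path
from llama_stack.apis.inference import SamplingParams  # assumed import path


async def run_my_benchmark(eval_impl) -> None:
    # eval_impl is assumed to be an initialized MetaReferenceEvalImpl.
    job = await eval_impl.run_eval(
        benchmark_id="my-benchmark",  # hypothetical registered benchmark
        benchmark_config=BenchmarkConfig(  # renamed from task_config
            eval_candidate=ModelCandidate(
                model="my-model-id",  # hypothetical model identifier
                sampling_params=SamplingParams(max_tokens=256),
            ),
            num_examples=10,  # evaluate only the first 10 rows
        ),
    )
    print(job.job_id)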