fix cli download

This commit is contained in:
Xi Yan 2025-02-12 21:38:40 -08:00
parent 76281f4de7
commit 234fe36d62
9 changed files with 34 additions and 33 deletions

View file

@ -108,8 +108,8 @@
"description": "",
"parameters": [
{
"name": "task_id",
"in": "path",
"name": "eval_task_id",
"in": "query",
"required": true,
"schema": {
"type": "string"
@ -3726,7 +3726,7 @@
"DeprecatedRegisterEvalTaskRequest": {
"type": "object",
"properties": {
"task_id": {
"eval_task_id": {
"type": "string"
},
"dataset_id": {
@ -3772,7 +3772,7 @@
},
"additionalProperties": false,
"required": [
"task_id",
"eval_task_id",
"dataset_id",
"scoring_functions"
]

View file

@ -50,8 +50,8 @@ paths:
- Benchmarks
description: ''
parameters:
- name: task_id
in: path
- name: eval_task_id
in: query
required: true
schema:
type: string
@ -2248,7 +2248,7 @@ components:
DeprecatedRegisterEvalTaskRequest:
type: object
properties:
task_id:
eval_task_id:
type: string
dataset_id:
type: string
@ -2272,7 +2272,7 @@ components:
- type: object
additionalProperties: false
required:
- task_id
- eval_task_id
- dataset_id
- scoring_functions
DeprecatedRunEvalRequest:

View file

@ -71,13 +71,13 @@ class Benchmarks(Protocol):
@webmethod(route="/eval-tasks/{task_id}", method="GET")
async def DEPRECATED_get_eval_task(
self,
task_id: str,
eval_task_id: str,
) -> Optional[Benchmark]: ...
@webmethod(route="/eval-tasks", method="POST")
async def DEPRECATED_register_eval_task(
self,
task_id: str,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_benchmark_id: Optional[str] = None,

View file

@ -39,6 +39,7 @@ EvalCandidate = register_schema(
@json_schema_type
class BenchmarkConfig(BaseModel):
type: Literal["benchmark"] = "benchmark"
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",

View file

@ -105,7 +105,7 @@ class DownloadTask:
output_file: str
total_size: int = 0
downloaded_size: int = 0
benchmark_id: Optional[int] = None
task_id: Optional[int] = None
retries: int = 0
max_retries: int = 3
@ -183,8 +183,8 @@ class ParallelDownloader:
)
# Update the progress bar's total size once we know it
if task.benchmark_id is not None:
self.progress.update(task.benchmark_id, total=task.total_size)
if task.task_id is not None:
self.progress.update(task.task_id, total=task.total_size)
except httpx.HTTPError as e:
self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
@ -207,7 +207,7 @@ class ParallelDownloader:
file.write(chunk)
task.downloaded_size += len(chunk)
self.progress.update(
task.benchmark_id,
task.task_id,
completed=task.downloaded_size,
)
@ -234,7 +234,7 @@ class ParallelDownloader:
if os.path.exists(task.output_file):
if self.verify_file_integrity(task):
self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
self.progress.update(task.benchmark_id, completed=task.total_size)
self.progress.update(task.task_id, completed=task.total_size)
return
await self.prepare_download(task)
@ -258,7 +258,7 @@ class ParallelDownloader:
raise DownloadError(f"Download failed: {str(e)}") from e
except Exception as e:
self.progress.update(task.benchmark_id, description=f"[red]Failed: {task.output_file}[/red]")
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
@ -293,7 +293,7 @@ class ParallelDownloader:
with self.progress:
for task in tasks:
desc = f"Downloading {Path(task.output_file).name}"
task.benchmark_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
semaphore = asyncio.Semaphore(self.max_concurrent_downloads)

View file

@ -82,7 +82,7 @@ def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -
) as progress:
for filepath, expected_hash in checksums.items():
full_path = model_dir / filepath
benchmark_id = progress.add_task(f"Verifying {filepath}...", total=None)
task_id = progress.add_task(f"Verifying {filepath}...", total=None)
exists = full_path.exists()
actual_hash = None
@ -102,7 +102,7 @@ def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -
)
)
progress.remove_task(benchmark_id)
progress.remove_task(task_id)
return results

View file

@ -475,14 +475,14 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def DEPRECATED_get_eval_task(
self,
task_id: str,
eval_task_id: str,
) -> Optional[Benchmark]:
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
return await self.get_benchmark(task_id)
return await self.get_benchmark(eval_task_id)
async def DEPRECATED_register_eval_task(
self,
task_id: str,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_benchmark_id: Optional[str] = None,
@ -491,7 +491,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
) -> None:
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
return await self.register_benchmark(
benchmark_id=task_id,
benchmark_id=eval_task_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,

View file

@ -205,7 +205,7 @@ class MetaReferenceEvalImpl(
# scoring with generated_answer
score_input_rows = [input_r | generated_r for input_r, generated_r in zip(input_rows, generations)]
if task_config.type == "app" and task_config.scoring_params is not None:
if task_config.scoring_params is not None:
scoring_functions_dict = {
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions

View file

@ -59,14 +59,14 @@ class Testeval:
scoring_functions = [
"basic::equality",
]
task_id = "meta-reference::app_eval"
benchmark_id = "meta-reference::app_eval"
await benchmarks_impl.register_benchmark(
benchmark_id=task_id,
benchmark_id=benchmark_id,
dataset_id="test_dataset_for_eval",
scoring_functions=scoring_functions,
)
response = await eval_impl.evaluate_rows(
task_id=task_id,
benchmark_id=benchmark_id,
input_rows=rows.rows,
scoring_functions=scoring_functions,
task_config=AppBenchmarkConfig(
@ -105,14 +105,14 @@ class Testeval:
"basic::subset_of",
]
task_id = "meta-reference::app_eval-2"
benchmark_id = "meta-reference::app_eval-2"
await benchmarks_impl.register_benchmark(
benchmark_id=task_id,
benchmark_id=benchmark_id,
dataset_id="test_dataset_for_eval",
scoring_functions=scoring_functions,
)
response = await eval_impl.run_eval(
task_id=task_id,
benchmark_id=benchmark_id,
task_config=AppBenchmarkConfig(
eval_candidate=ModelCandidate(
model=inference_model,
@ -121,9 +121,9 @@ class Testeval:
),
)
assert response.job_id == "0"
job_status = await eval_impl.job_status(task_id, response.job_id)
job_status = await eval_impl.job_status(benchmark_id, response.job_id)
assert job_status and job_status.value == "completed"
eval_response = await eval_impl.job_result(task_id, response.job_id)
eval_response = await eval_impl.job_result(benchmark_id, response.job_id)
assert eval_response is not None
assert len(eval_response.generations) == 5
@ -171,7 +171,7 @@ class Testeval:
benchmark_id = "meta-reference-mmlu"
response = await eval_impl.run_eval(
task_id=benchmark_id,
benchmark_id=benchmark_id,
task_config=BenchmarkBenchmarkConfig(
eval_candidate=ModelCandidate(
model=inference_model,