fix cli download

Xi Yan 2025-02-12 21:38:40 -08:00
parent 76281f4de7
commit 234fe36d62
9 changed files with 34 additions and 33 deletions

View file

@@ -108,8 +108,8 @@
         "description": "",
         "parameters": [
           {
-            "name": "task_id",
-            "in": "path",
+            "name": "eval_task_id",
+            "in": "query",
             "required": true,
             "schema": {
               "type": "string"
@@ -3726,7 +3726,7 @@
     "DeprecatedRegisterEvalTaskRequest": {
       "type": "object",
       "properties": {
-        "task_id": {
+        "eval_task_id": {
           "type": "string"
         },
         "dataset_id": {
@@ -3772,7 +3772,7 @@
       },
       "additionalProperties": false,
       "required": [
-        "task_id",
+        "eval_task_id",
         "dataset_id",
         "scoring_functions"
       ]

View file

@@ -50,8 +50,8 @@ paths:
       - Benchmarks
       description: ''
       parameters:
-      - name: task_id
-        in: path
+      - name: eval_task_id
+        in: query
         required: true
         schema:
           type: string
@@ -2248,7 +2248,7 @@ components:
     DeprecatedRegisterEvalTaskRequest:
       type: object
      properties:
-        task_id:
+        eval_task_id:
          type: string
        dataset_id:
          type: string
@@ -2272,7 +2272,7 @@ components:
       - type: object
       additionalProperties: false
       required:
-      - task_id
+      - eval_task_id
       - dataset_id
       - scoring_functions
     DeprecatedRunEvalRequest:

View file

@@ -71,13 +71,13 @@ class Benchmarks(Protocol):
     @webmethod(route="/eval-tasks/{task_id}", method="GET")
     async def DEPRECATED_get_eval_task(
         self,
-        task_id: str,
+        eval_task_id: str,
     ) -> Optional[Benchmark]: ...

     @webmethod(route="/eval-tasks", method="POST")
     async def DEPRECATED_register_eval_task(
         self,
-        task_id: str,
+        eval_task_id: str,
         dataset_id: str,
         scoring_functions: List[str],
         provider_benchmark_id: Optional[str] = None,

View file

@@ -39,6 +39,7 @@ EvalCandidate = register_schema(
 @json_schema_type
 class BenchmarkConfig(BaseModel):
+    type: Literal["benchmark"] = "benchmark"
     eval_candidate: EvalCandidate
     scoring_params: Dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",

View file

@@ -105,7 +105,7 @@ class DownloadTask:
     output_file: str
     total_size: int = 0
     downloaded_size: int = 0
-    benchmark_id: Optional[int] = None
+    task_id: Optional[int] = None
     retries: int = 0
     max_retries: int = 3
@@ -183,8 +183,8 @@ class ParallelDownloader:
             )
             # Update the progress bar's total size once we know it
-            if task.benchmark_id is not None:
-                self.progress.update(task.benchmark_id, total=task.total_size)
+            if task.task_id is not None:
+                self.progress.update(task.task_id, total=task.total_size)

         except httpx.HTTPError as e:
             self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
@@ -207,7 +207,7 @@ class ParallelDownloader:
                         file.write(chunk)
                         task.downloaded_size += len(chunk)
                         self.progress.update(
-                            task.benchmark_id,
+                            task.task_id,
                             completed=task.downloaded_size,
                         )
@@ -234,7 +234,7 @@ class ParallelDownloader:
         if os.path.exists(task.output_file):
             if self.verify_file_integrity(task):
                 self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
-                self.progress.update(task.benchmark_id, completed=task.total_size)
+                self.progress.update(task.task_id, completed=task.total_size)
                 return

         await self.prepare_download(task)
@@ -258,7 +258,7 @@ class ParallelDownloader:
                 raise DownloadError(f"Download failed: {str(e)}") from e
             except Exception as e:
-                self.progress.update(task.benchmark_id, description=f"[red]Failed: {task.output_file}[/red]")
+                self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
                 raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e

     def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
@@ -293,7 +293,7 @@ class ParallelDownloader:
         with self.progress:
             for task in tasks:
                 desc = f"Downloading {Path(task.output_file).name}"
-                task.benchmark_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
+                task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)

             semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
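
Side note on the rename above: the field stores the handle returned by Rich's Progress.add_task, which every later progress.update call takes as its first argument, so task_id matches Rich's own TaskID naming. A small self-contained sketch of that pattern, with made-up file names and sizes:

# Minimal illustration of the Rich progress pattern used in the downloader (values are made up).
import time
from dataclasses import dataclass
from typing import Optional

from rich.progress import Progress, TaskID


@dataclass
class DownloadTask:
    output_file: str
    total_size: int = 0
    downloaded_size: int = 0
    task_id: Optional[TaskID] = None  # handle returned by Progress.add_task


tasks = [DownloadTask("model-00001-of-00002.safetensors", total_size=100)]

with Progress() as progress:
    for task in tasks:
        # add_task returns the id that later update() calls expect
        task.task_id = progress.add_task(f"Downloading {task.output_file}", total=task.total_size)
    for task in tasks:
        while task.downloaded_size < task.total_size:
            task.downloaded_size += 10  # pretend a 10-byte chunk arrived
            progress.update(task.task_id, completed=task.downloaded_size)
            time.sleep(0.05)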

View file

@@ -82,7 +82,7 @@ def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -
     ) as progress:
         for filepath, expected_hash in checksums.items():
             full_path = model_dir / filepath
-            benchmark_id = progress.add_task(f"Verifying {filepath}...", total=None)
+            task_id = progress.add_task(f"Verifying {filepath}...", total=None)

             exists = full_path.exists()
             actual_hash = None
@@ -102,7 +102,7 @@ def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -
                 )
             )
-            progress.remove_task(benchmark_id)
+            progress.remove_task(task_id)

     return results

View file

@@ -475,14 +475,14 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
     async def DEPRECATED_get_eval_task(
         self,
-        task_id: str,
+        eval_task_id: str,
     ) -> Optional[Benchmark]:
         logger.warning("DEPRECATED: Use /eval/benchmarks instead")
-        return await self.get_benchmark(task_id)
+        return await self.get_benchmark(eval_task_id)

     async def DEPRECATED_register_eval_task(
         self,
-        task_id: str,
+        eval_task_id: str,
         dataset_id: str,
         scoring_functions: List[str],
         provider_benchmark_id: Optional[str] = None,
@@ -491,7 +491,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
     ) -> None:
         logger.warning("DEPRECATED: Use /eval/benchmarks instead")
         return await self.register_benchmark(
-            benchmark_id=task_id,
+            benchmark_id=eval_task_id,
             dataset_id=dataset_id,
             scoring_functions=scoring_functions,
             metadata=metadata,

View file

@@ -205,7 +205,7 @@ class MetaReferenceEvalImpl(
         # scoring with generated_answer
         score_input_rows = [input_r | generated_r for input_r, generated_r in zip(input_rows, generations)]

-        if task_config.type == "app" and task_config.scoring_params is not None:
+        if task_config.scoring_params is not None:
             scoring_functions_dict = {
                 scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
                 for scoring_fn_id in scoring_functions
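
With the config types unified, the "app"-only guard is no longer needed and scoring_params can be consulted the same way for any task. A tiny illustrative snippet of the lookup that follows this check (the parameter values here are hypothetical):

# Illustrative only: map each scoring function id to its configured params, or None.
scoring_functions = ["basic::equality", "basic::subset_of"]
scoring_params = {"basic::equality": {"aggregation_functions": ["accuracy"]}}  # hypothetical values

scoring_functions_dict = {
    scoring_fn_id: scoring_params.get(scoring_fn_id, None)
    for scoring_fn_id in scoring_functions
}
# -> {"basic::equality": {"aggregation_functions": ["accuracy"]}, "basic::subset_of": None}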

View file

@@ -59,14 +59,14 @@ class Testeval:
         scoring_functions = [
             "basic::equality",
         ]
-        task_id = "meta-reference::app_eval"
+        benchmark_id = "meta-reference::app_eval"
         await benchmarks_impl.register_benchmark(
-            benchmark_id=task_id,
+            benchmark_id=benchmark_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
         )
         response = await eval_impl.evaluate_rows(
-            task_id=task_id,
+            benchmark_id=benchmark_id,
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
             task_config=AppBenchmarkConfig(
@@ -105,14 +105,14 @@ class Testeval:
             "basic::subset_of",
         ]
-        task_id = "meta-reference::app_eval-2"
+        benchmark_id = "meta-reference::app_eval-2"
         await benchmarks_impl.register_benchmark(
-            benchmark_id=task_id,
+            benchmark_id=benchmark_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
         )
         response = await eval_impl.run_eval(
-            task_id=task_id,
+            benchmark_id=benchmark_id,
             task_config=AppBenchmarkConfig(
                 eval_candidate=ModelCandidate(
                     model=inference_model,
@@ -121,9 +121,9 @@ class Testeval:
             ),
         )
         assert response.job_id == "0"
-        job_status = await eval_impl.job_status(task_id, response.job_id)
+        job_status = await eval_impl.job_status(benchmark_id, response.job_id)
         assert job_status and job_status.value == "completed"
-        eval_response = await eval_impl.job_result(task_id, response.job_id)
+        eval_response = await eval_impl.job_result(benchmark_id, response.job_id)
         assert eval_response is not None
         assert len(eval_response.generations) == 5
@@ -171,7 +171,7 @@ class Testeval:
         benchmark_id = "meta-reference-mmlu"
         response = await eval_impl.run_eval(
-            task_id=benchmark_id,
+            benchmark_id=benchmark_id,
             task_config=BenchmarkBenchmarkConfig(
                 eval_candidate=ModelCandidate(
                     model=inference_model,