mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-09 19:58:29 +00:00
fix cli download
This commit is contained in:
parent
76281f4de7
commit
234fe36d62
9 changed files with 34 additions and 33 deletions
8
docs/_static/llama-stack-spec.html
vendored
8
docs/_static/llama-stack-spec.html
vendored
|
@ -108,8 +108,8 @@
|
|||
"description": "",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"name": "eval_task_id",
|
||||
"in": "query",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
|
@ -3726,7 +3726,7 @@
|
|||
"DeprecatedRegisterEvalTaskRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": {
|
||||
"eval_task_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"dataset_id": {
|
||||
|
@ -3772,7 +3772,7 @@
|
|||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"task_id",
|
||||
"eval_task_id",
|
||||
"dataset_id",
|
||||
"scoring_functions"
|
||||
]
|
||||
|
|
8
docs/_static/llama-stack-spec.yaml
vendored
8
docs/_static/llama-stack-spec.yaml
vendored
|
@ -50,8 +50,8 @@ paths:
|
|||
- Benchmarks
|
||||
description: ''
|
||||
parameters:
|
||||
- name: task_id
|
||||
in: path
|
||||
- name: eval_task_id
|
||||
in: query
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
@ -2248,7 +2248,7 @@ components:
|
|||
DeprecatedRegisterEvalTaskRequest:
|
||||
type: object
|
||||
properties:
|
||||
task_id:
|
||||
eval_task_id:
|
||||
type: string
|
||||
dataset_id:
|
||||
type: string
|
||||
|
@ -2272,7 +2272,7 @@ components:
|
|||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- task_id
|
||||
- eval_task_id
|
||||
- dataset_id
|
||||
- scoring_functions
|
||||
DeprecatedRunEvalRequest:
|
||||
|
|
|
@ -71,13 +71,13 @@ class Benchmarks(Protocol):
|
|||
@webmethod(route="/eval-tasks/{task_id}", method="GET")
|
||||
async def DEPRECATED_get_eval_task(
|
||||
self,
|
||||
task_id: str,
|
||||
eval_task_id: str,
|
||||
) -> Optional[Benchmark]: ...
|
||||
|
||||
@webmethod(route="/eval-tasks", method="POST")
|
||||
async def DEPRECATED_register_eval_task(
|
||||
self,
|
||||
task_id: str,
|
||||
eval_task_id: str,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
provider_benchmark_id: Optional[str] = None,
|
||||
|
|
|
@ -39,6 +39,7 @@ EvalCandidate = register_schema(
|
|||
|
||||
@json_schema_type
|
||||
class BenchmarkConfig(BaseModel):
|
||||
type: Literal["benchmark"] = "benchmark"
|
||||
eval_candidate: EvalCandidate
|
||||
scoring_params: Dict[str, ScoringFnParams] = Field(
|
||||
description="Map between scoring function id and parameters for each scoring function you want to run",
|
||||
|
|
|
@ -105,7 +105,7 @@ class DownloadTask:
|
|||
output_file: str
|
||||
total_size: int = 0
|
||||
downloaded_size: int = 0
|
||||
benchmark_id: Optional[int] = None
|
||||
task_id: Optional[int] = None
|
||||
retries: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
|
@ -183,8 +183,8 @@ class ParallelDownloader:
|
|||
)
|
||||
|
||||
# Update the progress bar's total size once we know it
|
||||
if task.benchmark_id is not None:
|
||||
self.progress.update(task.benchmark_id, total=task.total_size)
|
||||
if task.task_id is not None:
|
||||
self.progress.update(task.task_id, total=task.total_size)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
|
||||
|
@ -207,7 +207,7 @@ class ParallelDownloader:
|
|||
file.write(chunk)
|
||||
task.downloaded_size += len(chunk)
|
||||
self.progress.update(
|
||||
task.benchmark_id,
|
||||
task.task_id,
|
||||
completed=task.downloaded_size,
|
||||
)
|
||||
|
||||
|
@ -234,7 +234,7 @@ class ParallelDownloader:
|
|||
if os.path.exists(task.output_file):
|
||||
if self.verify_file_integrity(task):
|
||||
self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
|
||||
self.progress.update(task.benchmark_id, completed=task.total_size)
|
||||
self.progress.update(task.task_id, completed=task.total_size)
|
||||
return
|
||||
|
||||
await self.prepare_download(task)
|
||||
|
@ -258,7 +258,7 @@ class ParallelDownloader:
|
|||
raise DownloadError(f"Download failed: {str(e)}") from e
|
||||
|
||||
except Exception as e:
|
||||
self.progress.update(task.benchmark_id, description=f"[red]Failed: {task.output_file}[/red]")
|
||||
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
|
||||
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
|
||||
|
||||
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
|
||||
|
@ -293,7 +293,7 @@ class ParallelDownloader:
|
|||
with self.progress:
|
||||
for task in tasks:
|
||||
desc = f"Downloading {Path(task.output_file).name}"
|
||||
task.benchmark_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
|
||||
task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
|
||||
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
|
||||
|
||||
|
|
|
@ -82,7 +82,7 @@ def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -
|
|||
) as progress:
|
||||
for filepath, expected_hash in checksums.items():
|
||||
full_path = model_dir / filepath
|
||||
benchmark_id = progress.add_task(f"Verifying {filepath}...", total=None)
|
||||
task_id = progress.add_task(f"Verifying {filepath}...", total=None)
|
||||
|
||||
exists = full_path.exists()
|
||||
actual_hash = None
|
||||
|
@ -102,7 +102,7 @@ def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -
|
|||
)
|
||||
)
|
||||
|
||||
progress.remove_task(benchmark_id)
|
||||
progress.remove_task(task_id)
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
@ -475,14 +475,14 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
|
||||
async def DEPRECATED_get_eval_task(
|
||||
self,
|
||||
task_id: str,
|
||||
eval_task_id: str,
|
||||
) -> Optional[Benchmark]:
|
||||
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
|
||||
return await self.get_benchmark(task_id)
|
||||
return await self.get_benchmark(eval_task_id)
|
||||
|
||||
async def DEPRECATED_register_eval_task(
|
||||
self,
|
||||
task_id: str,
|
||||
eval_task_id: str,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
provider_benchmark_id: Optional[str] = None,
|
||||
|
@ -491,7 +491,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
) -> None:
|
||||
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
|
||||
return await self.register_benchmark(
|
||||
benchmark_id=task_id,
|
||||
benchmark_id=eval_task_id,
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=scoring_functions,
|
||||
metadata=metadata,
|
||||
|
|
|
@ -205,7 +205,7 @@ class MetaReferenceEvalImpl(
|
|||
# scoring with generated_answer
|
||||
score_input_rows = [input_r | generated_r for input_r, generated_r in zip(input_rows, generations)]
|
||||
|
||||
if task_config.type == "app" and task_config.scoring_params is not None:
|
||||
if task_config.scoring_params is not None:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
|
||||
for scoring_fn_id in scoring_functions
|
||||
|
|
|
@ -59,14 +59,14 @@ class Testeval:
|
|||
scoring_functions = [
|
||||
"basic::equality",
|
||||
]
|
||||
task_id = "meta-reference::app_eval"
|
||||
benchmark_id = "meta-reference::app_eval"
|
||||
await benchmarks_impl.register_benchmark(
|
||||
benchmark_id=task_id,
|
||||
benchmark_id=benchmark_id,
|
||||
dataset_id="test_dataset_for_eval",
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
response = await eval_impl.evaluate_rows(
|
||||
task_id=task_id,
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=AppBenchmarkConfig(
|
||||
|
@ -105,14 +105,14 @@ class Testeval:
|
|||
"basic::subset_of",
|
||||
]
|
||||
|
||||
task_id = "meta-reference::app_eval-2"
|
||||
benchmark_id = "meta-reference::app_eval-2"
|
||||
await benchmarks_impl.register_benchmark(
|
||||
benchmark_id=task_id,
|
||||
benchmark_id=benchmark_id,
|
||||
dataset_id="test_dataset_for_eval",
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
response = await eval_impl.run_eval(
|
||||
task_id=task_id,
|
||||
benchmark_id=benchmark_id,
|
||||
task_config=AppBenchmarkConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
model=inference_model,
|
||||
|
@ -121,9 +121,9 @@ class Testeval:
|
|||
),
|
||||
)
|
||||
assert response.job_id == "0"
|
||||
job_status = await eval_impl.job_status(task_id, response.job_id)
|
||||
job_status = await eval_impl.job_status(benchmark_id, response.job_id)
|
||||
assert job_status and job_status.value == "completed"
|
||||
eval_response = await eval_impl.job_result(task_id, response.job_id)
|
||||
eval_response = await eval_impl.job_result(benchmark_id, response.job_id)
|
||||
|
||||
assert eval_response is not None
|
||||
assert len(eval_response.generations) == 5
|
||||
|
@ -171,7 +171,7 @@ class Testeval:
|
|||
|
||||
benchmark_id = "meta-reference-mmlu"
|
||||
response = await eval_impl.run_eval(
|
||||
task_id=benchmark_id,
|
||||
benchmark_id=benchmark_id,
|
||||
task_config=BenchmarkBenchmarkConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
model=inference_model,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue