mirror of https://github.com/meta-llama/llama-stack.git

commit 234fe36d62: fix cli download
parent: 76281f4de7
9 changed files with 34 additions and 33 deletions
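Most of the diff below restores `task_id` as the name of the rich progress handle stored on `DownloadTask` and used by `ParallelDownloader` (it is a `rich` progress task ID, not a benchmark identifier), which an earlier eval-task-to-benchmark rename had swept up; the remaining hunks rename the deprecated eval-task parameters to `eval_task_id` and update the generated OpenAPI spec and tests accordingly. As context, here is a minimal, hypothetical sketch of the rich Progress pattern the CLI relies on; the class, field, and function names are illustrative, not the llama-stack implementation:

# Hypothetical sketch (not llama-stack code): each download record keeps the
# handle returned by rich's Progress.add_task() so later progress.update()
# calls can address the right progress bar. Keeping the field named task_id
# mirrors rich's own terminology (rich.progress.TaskID).
from dataclasses import dataclass
from typing import List, Optional

from rich.progress import Progress, TaskID


@dataclass
class DownloadItem:
    url: str
    total_size: int = 0
    downloaded_size: int = 0
    task_id: Optional[TaskID] = None  # handle returned by Progress.add_task()


def simulate(items: List[DownloadItem]) -> None:
    with Progress() as progress:
        for item in items:
            # add_task() returns a TaskID; every later update must reuse it
            item.task_id = progress.add_task(f"Downloading {item.url}", total=item.total_size)
        for item in items:
            while item.downloaded_size < item.total_size:
                item.downloaded_size += 1024
                progress.update(item.task_id, completed=item.downloaded_size)


if __name__ == "__main__":
    simulate([DownloadItem(url="https://example.com/model.bin", total_size=10 * 1024)])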
docs/_static/llama-stack-spec.html (vendored, 8 changes)

@@ -108,8 +108,8 @@
         "description": "",
         "parameters": [
           {
-            "name": "task_id",
-            "in": "path",
+            "name": "eval_task_id",
+            "in": "query",
             "required": true,
             "schema": {
               "type": "string"
@@ -3726,7 +3726,7 @@
     "DeprecatedRegisterEvalTaskRequest": {
       "type": "object",
       "properties": {
-        "task_id": {
+        "eval_task_id": {
           "type": "string"
         },
         "dataset_id": {
@@ -3772,7 +3772,7 @@
       },
       "additionalProperties": false,
       "required": [
-        "task_id",
+        "eval_task_id",
         "dataset_id",
         "scoring_functions"
       ]
docs/_static/llama-stack-spec.yaml (vendored, 8 changes)

@@ -50,8 +50,8 @@ paths:
       - Benchmarks
       description: ''
       parameters:
-      - name: task_id
-        in: path
+      - name: eval_task_id
+        in: query
         required: true
         schema:
           type: string
@@ -2248,7 +2248,7 @@ components:
   DeprecatedRegisterEvalTaskRequest:
     type: object
     properties:
-      task_id:
+      eval_task_id:
        type: string
      dataset_id:
        type: string
@@ -2272,7 +2272,7 @@ components:
     - type: object
     additionalProperties: false
     required:
-    - task_id
+    - eval_task_id
     - dataset_id
     - scoring_functions
   DeprecatedRunEvalRequest:
@@ -71,13 +71,13 @@ class Benchmarks(Protocol):
     @webmethod(route="/eval-tasks/{task_id}", method="GET")
     async def DEPRECATED_get_eval_task(
         self,
-        task_id: str,
+        eval_task_id: str,
     ) -> Optional[Benchmark]: ...

     @webmethod(route="/eval-tasks", method="POST")
     async def DEPRECATED_register_eval_task(
         self,
-        task_id: str,
+        eval_task_id: str,
         dataset_id: str,
         scoring_functions: List[str],
         provider_benchmark_id: Optional[str] = None,
@@ -39,6 +39,7 @@ EvalCandidate = register_schema(

 @json_schema_type
 class BenchmarkConfig(BaseModel):
+    type: Literal["benchmark"] = "benchmark"
     eval_candidate: EvalCandidate
     scoring_params: Dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",
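The hunk above appears to be the only pure addition in the commit: a `type: Literal["benchmark"] = "benchmark"` tag on `BenchmarkConfig`. A Literal-typed field like this is the usual way to tag Pydantic models so a union of config types can be told apart by a discriminator. A minimal, hypothetical sketch of that pattern follows; the class names and the `discriminator` wiring are illustrative, not the llama-stack models:

# Hypothetical sketch of a Literal "type" tag driving a Pydantic v2
# discriminated union; not the llama-stack eval config classes.
from typing import Dict, Literal, Union

from pydantic import BaseModel, Field


class AppConfig(BaseModel):
    type: Literal["app"] = "app"
    scoring_params: Dict[str, str] = Field(default_factory=dict)


class BenchmarkConfig(BaseModel):
    type: Literal["benchmark"] = "benchmark"
    eval_candidate: str


class EvalRequest(BaseModel):
    # Pydantic selects the union member by inspecting the "type" tag
    task_config: Union[AppConfig, BenchmarkConfig] = Field(discriminator="type")


req = EvalRequest.model_validate({"task_config": {"type": "benchmark", "eval_candidate": "model-x"}})
assert isinstance(req.task_config, BenchmarkConfig)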
@@ -105,7 +105,7 @@ class DownloadTask:
     output_file: str
     total_size: int = 0
     downloaded_size: int = 0
-    benchmark_id: Optional[int] = None
+    task_id: Optional[int] = None
     retries: int = 0
     max_retries: int = 3

@@ -183,8 +183,8 @@ class ParallelDownloader:
                 )

                 # Update the progress bar's total size once we know it
-                if task.benchmark_id is not None:
-                    self.progress.update(task.benchmark_id, total=task.total_size)
+                if task.task_id is not None:
+                    self.progress.update(task.task_id, total=task.total_size)

         except httpx.HTTPError as e:
             self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
@@ -207,7 +207,7 @@ class ParallelDownloader:
                         file.write(chunk)
                         task.downloaded_size += len(chunk)
                         self.progress.update(
-                            task.benchmark_id,
+                            task.task_id,
                             completed=task.downloaded_size,
                         )

@@ -234,7 +234,7 @@ class ParallelDownloader:
         if os.path.exists(task.output_file):
             if self.verify_file_integrity(task):
                 self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
-                self.progress.update(task.benchmark_id, completed=task.total_size)
+                self.progress.update(task.task_id, completed=task.total_size)
                 return

         await self.prepare_download(task)
@@ -258,7 +258,7 @@ class ParallelDownloader:
                 raise DownloadError(f"Download failed: {str(e)}") from e

         except Exception as e:
-            self.progress.update(task.benchmark_id, description=f"[red]Failed: {task.output_file}[/red]")
+            self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
             raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e

     def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
@@ -293,7 +293,7 @@ class ParallelDownloader:
         with self.progress:
             for task in tasks:
                 desc = f"Downloading {Path(task.output_file).name}"
-                task.benchmark_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
+                task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)

             semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
@@ -82,7 +82,7 @@ def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -
     ) as progress:
         for filepath, expected_hash in checksums.items():
             full_path = model_dir / filepath
-            benchmark_id = progress.add_task(f"Verifying {filepath}...", total=None)
+            task_id = progress.add_task(f"Verifying {filepath}...", total=None)

             exists = full_path.exists()
             actual_hash = None
@@ -102,7 +102,7 @@ def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -
                 )
             )

-            progress.remove_task(benchmark_id)
+            progress.remove_task(task_id)

     return results
@@ -475,14 +475,14 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

     async def DEPRECATED_get_eval_task(
         self,
-        task_id: str,
+        eval_task_id: str,
     ) -> Optional[Benchmark]:
         logger.warning("DEPRECATED: Use /eval/benchmarks instead")
-        return await self.get_benchmark(task_id)
+        return await self.get_benchmark(eval_task_id)

     async def DEPRECATED_register_eval_task(
         self,
-        task_id: str,
+        eval_task_id: str,
         dataset_id: str,
         scoring_functions: List[str],
         provider_benchmark_id: Optional[str] = None,
@@ -491,7 +491,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
     ) -> None:
         logger.warning("DEPRECATED: Use /eval/benchmarks instead")
         return await self.register_benchmark(
-            benchmark_id=task_id,
+            benchmark_id=eval_task_id,
             dataset_id=dataset_id,
             scoring_functions=scoring_functions,
             metadata=metadata,
@@ -205,7 +205,7 @@ class MetaReferenceEvalImpl(
         # scoring with generated_answer
         score_input_rows = [input_r | generated_r for input_r, generated_r in zip(input_rows, generations)]

-        if task_config.type == "app" and task_config.scoring_params is not None:
+        if task_config.scoring_params is not None:
             scoring_functions_dict = {
                 scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
                 for scoring_fn_id in scoring_functions
@@ -59,14 +59,14 @@ class Testeval:
         scoring_functions = [
             "basic::equality",
         ]
-        task_id = "meta-reference::app_eval"
+        benchmark_id = "meta-reference::app_eval"
         await benchmarks_impl.register_benchmark(
-            benchmark_id=task_id,
+            benchmark_id=benchmark_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
         )
         response = await eval_impl.evaluate_rows(
-            task_id=task_id,
+            benchmark_id=benchmark_id,
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
             task_config=AppBenchmarkConfig(
@@ -105,14 +105,14 @@ class Testeval:
             "basic::subset_of",
         ]

-        task_id = "meta-reference::app_eval-2"
+        benchmark_id = "meta-reference::app_eval-2"
         await benchmarks_impl.register_benchmark(
-            benchmark_id=task_id,
+            benchmark_id=benchmark_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
         )
         response = await eval_impl.run_eval(
-            task_id=task_id,
+            benchmark_id=benchmark_id,
             task_config=AppBenchmarkConfig(
                 eval_candidate=ModelCandidate(
                     model=inference_model,
@@ -121,9 +121,9 @@ class Testeval:
                 ),
             )
         assert response.job_id == "0"
-        job_status = await eval_impl.job_status(task_id, response.job_id)
+        job_status = await eval_impl.job_status(benchmark_id, response.job_id)
         assert job_status and job_status.value == "completed"
-        eval_response = await eval_impl.job_result(task_id, response.job_id)
+        eval_response = await eval_impl.job_result(benchmark_id, response.job_id)

         assert eval_response is not None
         assert len(eval_response.generations) == 5
@@ -171,7 +171,7 @@ class Testeval:

         benchmark_id = "meta-reference-mmlu"
         response = await eval_impl.run_eval(
-            task_id=benchmark_id,
+            benchmark_id=benchmark_id,
             task_config=BenchmarkBenchmarkConfig(
                 eval_candidate=ModelCandidate(
                     model=inference_model,