mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
chore(api): add mypy coverage to cli
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
cd0ad21111
commit
66881b9cca
2 changed files with 34 additions and 21 deletions
|
@ -22,6 +22,7 @@ from rich.progress import (
|
|||
BarColumn,
|
||||
DownloadColumn,
|
||||
Progress,
|
||||
TaskID,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
TransferSpeedColumn,
|
||||
|
@ -102,7 +103,7 @@ class DownloadTask:
|
|||
output_file: str
|
||||
total_size: int = 0
|
||||
downloaded_size: int = 0
|
||||
task_id: int | None = None
|
||||
task_id: TaskID | None = None
|
||||
retries: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
|
@ -139,13 +140,13 @@ class ParallelDownloader:
|
|||
console=self.console,
|
||||
expand=True,
|
||||
)
|
||||
self.client_options = {
|
||||
self.client_options: dict[str, object] = {
|
||||
"timeout": httpx.Timeout(timeout),
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
|
||||
last_exception = None
|
||||
last_exception: Exception | None = None
|
||||
for attempt in range(task.max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
|
@ -159,15 +160,18 @@ class ParallelDownloader:
|
|||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
raise last_exception
|
||||
if last_exception is not None:
|
||||
raise last_exception
|
||||
raise RuntimeError("Retry failed without capturing exception")
|
||||
|
||||
async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
|
||||
if task.total_size > 0:
|
||||
self.progress.update(task.task_id, total=task.total_size)
|
||||
if task.task_id is not None:
|
||||
self.progress.update(task.task_id, total=task.total_size)
|
||||
return
|
||||
|
||||
async def _get_info():
|
||||
response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
|
||||
response = await client.head(task.url, headers={"Accept-Encoding": "identity"})
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
|
@ -199,7 +203,7 @@ class ParallelDownloader:
|
|||
async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
|
||||
async def _download_chunk():
|
||||
headers = {"Range": f"bytes={start}-{end}"}
|
||||
async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
|
||||
async with client.stream("GET", task.url, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
with open(task.output_file, "ab") as file:
|
||||
|
@ -207,10 +211,11 @@ class ParallelDownloader:
|
|||
async for chunk in response.aiter_bytes(self.buffer_size):
|
||||
file.write(chunk)
|
||||
task.downloaded_size += len(chunk)
|
||||
self.progress.update(
|
||||
task.task_id,
|
||||
completed=task.downloaded_size,
|
||||
)
|
||||
if task.task_id is not None:
|
||||
self.progress.update(
|
||||
task.task_id,
|
||||
completed=task.downloaded_size,
|
||||
)
|
||||
|
||||
try:
|
||||
await self.retry_with_exponential_backoff(task, _download_chunk)
|
||||
|
@ -228,14 +233,21 @@ class ParallelDownloader:
|
|||
|
||||
async def download_file(self, task: DownloadTask) -> None:
|
||||
try:
|
||||
async with httpx.AsyncClient(**self.client_options) as client:
|
||||
client_timeout = self.client_options["timeout"]
|
||||
if not isinstance(client_timeout, httpx.Timeout):
|
||||
raise TypeError(f"Expected httpx.Timeout, got {type(client_timeout)}")
|
||||
async with httpx.AsyncClient(
|
||||
timeout=client_timeout,
|
||||
follow_redirects=bool(self.client_options["follow_redirects"]),
|
||||
) as client:
|
||||
await self.get_file_info(client, task)
|
||||
|
||||
# Check if file is already downloaded
|
||||
if os.path.exists(task.output_file):
|
||||
if self.verify_file_integrity(task):
|
||||
self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
|
||||
self.progress.update(task.task_id, completed=task.total_size)
|
||||
if task.task_id is not None:
|
||||
self.progress.update(task.task_id, completed=task.total_size)
|
||||
return
|
||||
|
||||
await self.prepare_download(task)
|
||||
|
@ -259,7 +271,8 @@ class ParallelDownloader:
|
|||
raise DownloadError(f"Download failed: {str(e)}") from e
|
||||
|
||||
except Exception as e:
|
||||
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
|
||||
if task.task_id is not None:
|
||||
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
|
||||
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
|
||||
|
||||
def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
|
||||
|
@ -349,7 +362,7 @@ def _hf_download(
|
|||
except RepositoryNotFoundError:
|
||||
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
|
||||
except Exception as e:
|
||||
parser.error(e)
|
||||
parser.error(str(e))
|
||||
|
||||
print(f"\nSuccessfully downloaded model to {true_output_dir}")
|
||||
|
||||
|
@ -465,13 +478,13 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
|||
prompt_guard_model_sku_map,
|
||||
)
|
||||
|
||||
prompt_guard_model_sku_map = prompt_guard_model_sku_map()
|
||||
prompt_guard_download_info_map = prompt_guard_download_info_map()
|
||||
prompt_guard_model_sku_map_dict = prompt_guard_model_sku_map()
|
||||
prompt_guard_download_info_map_dict = prompt_guard_download_info_map()
|
||||
|
||||
for model_id in model_ids:
|
||||
if model_id in prompt_guard_model_sku_map.keys():
|
||||
model = prompt_guard_model_sku_map[model_id]
|
||||
info = prompt_guard_download_info_map[model_id]
|
||||
if model_id in prompt_guard_model_sku_map_dict.keys():
|
||||
model = prompt_guard_model_sku_map_dict[model_id]
|
||||
info = prompt_guard_download_info_map_dict[model_id]
|
||||
else:
|
||||
model = resolve_model(model_id)
|
||||
if model is None:
|
||||
|
|
|
@ -225,7 +225,7 @@ follow_imports = "silent"
|
|||
# to exclude the entire directory.
|
||||
exclude = [
|
||||
# As we fix more and more of these, we should remove them from the list
|
||||
"^llama_stack/cli/download\\.py$",
|
||||
"^llama_stack/apis/common/training_types\\.py$",
|
||||
"^llama_stack/cli/stack/_build\\.py$",
|
||||
"^llama_stack/distribution/build\\.py$",
|
||||
"^llama_stack/distribution/client\\.py$",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue