Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
chore(api): add mypy coverage to cli

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>

commit 66881b9cca (parent cd0ad21111)

2 changed files with 34 additions and 21 deletions
llama_stack/cli/download.py

@@ -22,6 +22,7 @@ from rich.progress import (
     BarColumn,
     DownloadColumn,
     Progress,
+    TaskID,
     TextColumn,
     TimeRemainingColumn,
     TransferSpeedColumn,
@@ -102,7 +103,7 @@ class DownloadTask:
     output_file: str
     total_size: int = 0
     downloaded_size: int = 0
-    task_id: int | None = None
+    task_id: TaskID | None = None
     retries: int = 0
     max_retries: int = 3

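The two hunks above import `TaskID` from `rich.progress` and retype `DownloadTask.task_id` from `int | None` to `TaskID | None`. For context, rich defines `TaskID` as a `NewType` over `int`; `Progress.add_task()` returns it and `Progress.update()` expects it, so the plain `int` annotation is what mypy objects to. A minimal, self-contained sketch of the pattern (the `Item` dataclass and its field names are illustrative, not from this commit):

```python
from dataclasses import dataclass

from rich.progress import Progress, TaskID  # TaskID is a NewType over int


@dataclass
class Item:
    name: str
    task_id: TaskID | None = None  # None until the task has been registered


with Progress() as progress:
    item = Item(name="model.bin")
    item.task_id = progress.add_task("downloading", total=100)  # returns TaskID
    if item.task_id is not None:  # narrows TaskID | None to TaskID for mypy
        progress.update(item.task_id, advance=10)
```

The same `is not None` narrowing is what the later hunks add in front of each `self.progress.update(...)` call.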
@@ -139,13 +140,13 @@ class ParallelDownloader:
             console=self.console,
             expand=True,
         )
-        self.client_options = {
+        self.client_options: dict[str, object] = {
             "timeout": httpx.Timeout(timeout),
             "follow_redirects": True,
         }

     async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
-        last_exception = None
+        last_exception: Exception | None = None
         for attempt in range(task.max_retries):
             try:
                 return await func(*args, **kwargs)
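Annotating `self.client_options` as `dict[str, object]` makes the heterogeneous literal explicit for mypy, but it also means every value read back out of the dict is typed as `object` and has to be narrowed before use, which is presumably why the later hunks stop unpacking `**self.client_options` into `httpx` calls. A rough sketch of that trade-off, using a synchronous `httpx.Client` and illustrative variable names (not code from this commit):

```python
import httpx

client_options: dict[str, object] = {
    "timeout": httpx.Timeout(30.0),
    "follow_redirects": True,
}

# Values come back as `object`, so they must be narrowed before being passed on.
timeout = client_options["timeout"]
if not isinstance(timeout, httpx.Timeout):
    raise TypeError(f"Expected httpx.Timeout, got {type(timeout)}")

client = httpx.Client(
    timeout=timeout,  # now typed as httpx.Timeout
    follow_redirects=bool(client_options["follow_redirects"]),
)
```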
@@ -159,15 +160,18 @@ class ParallelDownloader:
                     )
                     await asyncio.sleep(wait_time)
                     continue
-        raise last_exception
+        if last_exception is not None:
+            raise last_exception
+        raise RuntimeError("Retry failed without capturing exception")

     async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
         if task.total_size > 0:
-            self.progress.update(task.task_id, total=task.total_size)
+            if task.task_id is not None:
+                self.progress.update(task.task_id, total=task.total_size)
             return

         async def _get_info():
-            response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
+            response = await client.head(task.url, headers={"Accept-Encoding": "identity"})
             response.raise_for_status()
             return response

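The retry hunk above annotates `last_exception` as `Exception | None` and only reaches `raise last_exception` behind an `is not None` check, since `raise` needs an actual exception object rather than a possibly-`None` value; the trailing `RuntimeError` covers the normally unreachable path where the loop never captured anything. A condensed sketch of the same shape, using a hypothetical standalone `retry` helper rather than the class method in the diff:

```python
import asyncio


async def retry(func, *args, max_retries: int = 3, **kwargs):
    last_exception: Exception | None = None
    for attempt in range(max_retries):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            last_exception = e
            if attempt < max_retries - 1:
                await asyncio.sleep(2**attempt)  # exponential backoff between attempts
    if last_exception is not None:  # narrows Exception | None to Exception
        raise last_exception
    raise RuntimeError("Retry failed without capturing exception")  # only reachable if the loop never ran
```

A call site would then look something like `asyncio.run(retry(fetch, url))` for some async `fetch`.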
@@ -199,7 +203,7 @@ class ParallelDownloader:
     async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
         async def _download_chunk():
             headers = {"Range": f"bytes={start}-{end}"}
-            async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
+            async with client.stream("GET", task.url, headers=headers) as response:
                 response.raise_for_status()

                 with open(task.output_file, "ab") as file:
@@ -207,6 +211,7 @@ class ParallelDownloader:
                     async for chunk in response.aiter_bytes(self.buffer_size):
                         file.write(chunk)
                         task.downloaded_size += len(chunk)
+                        if task.task_id is not None:
                             self.progress.update(
                                 task.task_id,
                                 completed=task.downloaded_size,
@@ -228,13 +233,20 @@ class ParallelDownloader:

     async def download_file(self, task: DownloadTask) -> None:
         try:
-            async with httpx.AsyncClient(**self.client_options) as client:
+            client_timeout = self.client_options["timeout"]
+            if not isinstance(client_timeout, httpx.Timeout):
+                raise TypeError(f"Expected httpx.Timeout, got {type(client_timeout)}")
+            async with httpx.AsyncClient(
+                timeout=client_timeout,
+                follow_redirects=bool(self.client_options["follow_redirects"]),
+            ) as client:
                 await self.get_file_info(client, task)

                 # Check if file is already downloaded
                 if os.path.exists(task.output_file):
                     if self.verify_file_integrity(task):
                         self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
+                        if task.task_id is not None:
                             self.progress.update(task.task_id, completed=task.total_size)
                         return

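The hunk above reconstructs the client explicitly (pull the timeout out of `client_options`, check it with `isinstance`, then pass keyword arguments) because `dict[str, object]` is too loose to unpack into `httpx.AsyncClient(...)` under mypy. A hypothetical alternative, not what this commit does, would be to give the options a `TypedDict` so the per-key types are preserved and no runtime narrowing is needed:

```python
from typing import TypedDict

import httpx


class ClientOptions(TypedDict):
    timeout: httpx.Timeout
    follow_redirects: bool


options: ClientOptions = {
    "timeout": httpx.Timeout(30.0),
    "follow_redirects": True,
}


def make_client(options: ClientOptions) -> httpx.AsyncClient:
    # mypy already knows options["timeout"] is an httpx.Timeout here.
    return httpx.AsyncClient(
        timeout=options["timeout"],
        follow_redirects=options["follow_redirects"],
    )
```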
@@ -259,6 +271,7 @@ class ParallelDownloader:
             raise DownloadError(f"Download failed: {str(e)}") from e

         except Exception as e:
+            if task.task_id is not None:
                 self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
             raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e

@@ -349,7 +362,7 @@ def _hf_download(
     except RepositoryNotFoundError:
         parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
     except Exception as e:
-        parser.error(e)
+        parser.error(str(e))

     print(f"\nSuccessfully downloaded model to {true_output_dir}")

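`argparse.ArgumentParser.error()` is annotated as taking a `str` (it prints the message and exits), so passing the exception object itself is an argument-type error under mypy; converting with `str(e)` keeps the behaviour and satisfies the checker. A tiny standalone illustration, not repository code:

```python
import argparse

parser = argparse.ArgumentParser()
try:
    raise ValueError("bad value")
except Exception as e:
    parser.error(str(e))  # parser.error(e) would fail mypy's argument-type check
```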
@@ -465,13 +478,13 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
         prompt_guard_model_sku_map,
     )

-    prompt_guard_model_sku_map = prompt_guard_model_sku_map()
-    prompt_guard_download_info_map = prompt_guard_download_info_map()
+    prompt_guard_model_sku_map_dict = prompt_guard_model_sku_map()
+    prompt_guard_download_info_map_dict = prompt_guard_download_info_map()

     for model_id in model_ids:
-        if model_id in prompt_guard_model_sku_map.keys():
-            model = prompt_guard_model_sku_map[model_id]
-            info = prompt_guard_download_info_map[model_id]
+        if model_id in prompt_guard_model_sku_map_dict.keys():
+            model = prompt_guard_model_sku_map_dict[model_id]
+            info = prompt_guard_download_info_map_dict[model_id]
         else:
             model = resolve_model(model_id)
             if model is None:
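The last hunk in this file renames the local results to `*_dict` because re-binding an imported function name to its own return value gives the variable two incompatible types, which mypy rejects. A small sketch with a hypothetical helper (`sku_map` is not a real function in the repository):

```python
def sku_map() -> dict[str, str]:
    return {"example-sku": "example-model"}


# sku_map = sku_map()  # mypy flags re-binding the function name to its return value
sku_map_dict = sku_map()  # a fresh name keeps the function and its result distinct
print(sku_map_dict["example-sku"])
```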
pyproject.toml

@@ -225,7 +225,7 @@ follow_imports = "silent"
 # to exclude the entire directory.
 exclude = [
     # As we fix more and more of these, we should remove them from the list
-    "^llama_stack/cli/download\\.py$",
+    "^llama_stack/apis/common/training_types\\.py$",
     "^llama_stack/cli/stack/_build\\.py$",
     "^llama_stack/distribution/build\\.py$",
     "^llama_stack/distribution/client\\.py$",
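For reference, mypy's `exclude` setting holds regular expressions matched against file paths, which is why the entries above escape the dot and anchor both ends. A quick Python illustration using one of the patterns from the hunk:

```python
import re

pattern = re.compile(r"^llama_stack/cli/stack/_build\.py$")
print(bool(pattern.search("llama_stack/cli/stack/_build.py")))   # True: path is excluded from checking
print(bool(pattern.search("llama_stack/cli/stack/_build.pyc")))  # False: the $ anchor keeps the match exact
```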