Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file, as well as fix and ignore some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
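
For context, a minimal `ruff.toml` along the lines described above might look like the sketch below; the line length and rule selection are illustrative assumptions, not a reproduction of the committed file:

# Illustrative sketch only; the committed ruff.toml may differ.
line-length = 120

[lint]
select = [
    "E",  # pycodestyle errors
    "F",  # pyflakes
    "I",  # isort-style import sorting
    "B",  # flake8-bugbear
]
ignore = [
    "E501",  # long lines are left to the formatter
]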
Yuan Tang, 2025-02-02 09:46:45 -05:00, committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
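
The failing check in the title is the ruff pre-commit hook. A typical hook setup (a sketch assuming the standard astral-sh hooks, not necessarily this repository's exact `.pre-commit-config.yaml`) is:

# Sketch of a pre-commit configuration for ruff; the rev shown is an assumption.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.9.4
    hooks:
      - id: ruff         # runs `ruff check`
        args: [--fix]
      - id: ruff-format  # runs `ruff format`

With this in place, `pre-commit run --all-files` reproduces the CI lint check locally.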


@@ -147,9 +147,7 @@ class ParallelDownloader:
             "follow_redirects": True,
         }
 
-    async def retry_with_exponential_backoff(
-        self, task: DownloadTask, func, *args, **kwargs
-    ):
+    async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
         last_exception = None
         for attempt in range(task.max_retries):
             try:
@@ -166,13 +164,9 @@ class ParallelDownloader:
                     continue
 
         raise last_exception
 
-    async def get_file_info(
-        self, client: httpx.AsyncClient, task: DownloadTask
-    ) -> None:
+    async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
         async def _get_info():
-            response = await client.head(
-                task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
-            )
+            response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
             response.raise_for_status()
             return response
@@ -201,14 +195,10 @@ class ParallelDownloader:
             return False
         return os.path.getsize(task.output_file) == task.total_size
 
-    async def download_chunk(
-        self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
-    ) -> None:
+    async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
         async def _download_chunk():
             headers = {"Range": f"bytes={start}-{end}"}
-            async with client.stream(
-                "GET", task.url, headers=headers, **self.client_options
-            ) as response:
+            async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
                 response.raise_for_status()
 
                 with open(task.output_file, "ab") as file:
@@ -225,8 +215,7 @@ class ParallelDownloader:
             await self.retry_with_exponential_backoff(task, _download_chunk)
         except Exception as e:
             raise DownloadError(
-                f"Failed to download chunk {start}-{end} after "
-                f"{task.max_retries} attempts: {str(e)}"
+                f"Failed to download chunk {start}-{end} after {task.max_retries} attempts: {str(e)}"
             ) from e
 
     async def prepare_download(self, task: DownloadTask) -> None:
@@ -244,9 +233,7 @@ class ParallelDownloader:
         # Check if file is already downloaded
         if os.path.exists(task.output_file):
             if self.verify_file_integrity(task):
-                self.console.print(
-                    f"[green]Already downloaded {task.output_file}[/green]"
-                )
+                self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
                 self.progress.update(task.task_id, completed=task.total_size)
                 return
 
@@ -259,9 +246,7 @@ class ParallelDownloader:
 
         current_pos = task.downloaded_size
         while current_pos < task.total_size:
-            chunk_end = min(
-                current_pos + chunk_size - 1, task.total_size - 1
-            )
+            chunk_end = min(current_pos + chunk_size - 1, task.total_size - 1)
             chunks.append((current_pos, chunk_end))
             current_pos = chunk_end + 1
 
@@ -273,18 +258,12 @@ class ParallelDownloader:
 
             raise DownloadError(f"Download failed: {str(e)}") from e
         except Exception as e:
-            self.progress.update(
-                task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
-            )
-            raise DownloadError(
-                f"Download failed for {task.output_file}: {str(e)}"
-            ) from e
+            self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
+            raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
 
     def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
         try:
-            total_remaining_size = sum(
-                task.total_size - task.downloaded_size for task in tasks
-            )
+            total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
             dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
             free_space = shutil.disk_usage(dir_path).free
 
@@ -314,9 +293,7 @@ class ParallelDownloader:
         with self.progress:
             for task in tasks:
                 desc = f"Downloading {Path(task.output_file).name}"
-                task.task_id = self.progress.add_task(
-                    desc, total=task.total_size, completed=task.downloaded_size
-                )
+                task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
 
             semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
 
@@ -332,9 +309,7 @@ class ParallelDownloader:
 
             if failed_tasks:
                 self.console.print("\n[red]Some downloads failed:[/red]")
                 for task, error in failed_tasks:
-                    self.console.print(
-                        f"[red]- {Path(task.output_file).name}: {error}[/red]"
-                    )
+                    self.console.print(f"[red]- {Path(task.output_file).name}: {error}[/red]")
                 raise DownloadError(f"{len(failed_tasks)} downloads failed")
 
@@ -396,11 +371,7 @@ def _meta_download(
         output_file = str(output_dir / f)
         url = meta_url.replace("*", f"{info.folder}/{f}")
         total_size = info.pth_size if "consolidated" in f else 0
-        tasks.append(
-            DownloadTask(
-                url=url, output_file=output_file, total_size=total_size, max_retries=3
-            )
-        )
+        tasks.append(DownloadTask(url=url, output_file=output_file, total_size=total_size, max_retries=3))
 
     # Initialize and run parallel downloader
     downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
@@ -446,14 +417,10 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
         os.makedirs(output_dir, exist_ok=True)
 
         if any(output_dir.iterdir()):
-            console.print(
-                f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
-            )
+            console.print(f"[yellow]Output directory {output_dir} is not empty.[/yellow]")
 
             while True:
-                resp = input(
-                    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
-                )
+                resp = input("Do you want to (C)ontinue download or (R)estart completely? (continue/restart): ")
                 if resp.lower() in ["restart", "r"]:
                     shutil.rmtree(output_dir)
                     os.makedirs(output_dir, exist_ok=True)
@@ -471,9 +438,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
         ]
 
         # Initialize and run parallel downloader
-        downloader = ParallelDownloader(
-            max_concurrent_downloads=max_concurrent_downloads
-        )
+        downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
         asyncio.run(downloader.download_all(tasks))
 
 


@@ -47,33 +47,20 @@ class ModelPromptFormat(Subcommand):
         # Only Llama 3.1 and 3.2 are supported
         supported_model_ids = [
-            m
-            for m in CoreModelId
-            if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
+            m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
         ]
         model_str = "\n".join([m.value for m in supported_model_ids])
 
         try:
             model_id = CoreModelId(args.model_name)
         except ValueError:
-            self.parser.error(
-                f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}"
-            )
+            self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
 
         if model_id not in supported_model_ids:
-            self.parser.error(
-                f"{model_id} is not a valid Model. Choose one from --\n {model_str}"
-            )
+            self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
 
-        llama_3_1_file = (
-            importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
-        )
-        llama_3_2_text_file = (
-            importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
-        )
-        llama_3_2_vision_file = (
-            importlib.resources.files("llama_models")
-            / "llama3_2/vision_prompt_format.md"
-        )
+        llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
+        llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
+        llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md"
 
         if model_family(model_id) == ModelFamily.llama3_1:
             with importlib.resources.as_file(llama_3_1_file) as f:
                 content = f.open("r").read()


@@ -17,16 +17,12 @@ class PromptGuardModel(BaseModel):
     """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
 
     model_id: str = "Prompt-Guard-86M"
-    description: str = (
-        "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
-    )
+    description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
     is_featured: bool = False
     huggingface_repo: str = "meta-llama/Prompt-Guard-86M"
     max_seq_length: int = 2048
     is_instruct_model: bool = False
-    quantization_format: CheckpointQuantizationFormat = (
-        CheckpointQuantizationFormat.bf16
-    )
+    quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
     arch_args: Dict[str, Any] = Field(default_factory=dict)
     recommended_sampling_params: Optional[SamplingParams] = None
 


@@ -56,9 +56,7 @@ def available_templates_specs() -> Dict[str, BuildConfig]:
     return template_specs
 
 
-def run_stack_build_command(
-    parser: argparse.ArgumentParser, args: argparse.Namespace
-) -> None:
+def run_stack_build_command(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
     if args.list_templates:
         return _run_template_list_cmd()
 
@@ -129,11 +127,7 @@ def run_stack_build_command(
 
             providers = dict()
             for api, providers_for_api in get_provider_registry().items():
-                available_providers = [
-                    x
-                    for x in providers_for_api.keys()
-                    if x not in ("remote", "remote::sample")
-                ]
+                available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
                 api_provider = prompt(
                     "> Enter provider for API {}: ".format(api.value),
                     completer=WordCompleter(available_providers),
@@ -156,9 +150,7 @@ def run_stack_build_command(
                 description=description,
             )
 
-            build_config = BuildConfig(
-                image_type=image_type, distribution_spec=distribution_spec
-            )
+            build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
         else:
             with open(args.config, "r") as f:
                 try:
@@ -179,9 +171,7 @@ def run_stack_build_command(
 
     if args.print_deps_only:
         print(f"# Dependencies for {args.template or args.config or image_name}")
-        normal_deps, special_deps = get_provider_dependencies(
-            build_config.distribution_spec.providers
-        )
+        normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
        normal_deps += SERVER_DEPENDENCIES
         print(f"uv pip install {' '.join(normal_deps)}")
         for special_dep in special_deps:
@@ -206,9 +196,7 @@ def _generate_run_config(
     """
     apis = list(build_config.distribution_spec.providers.keys())
     run_config = StackRunConfig(
-        container_image=(
-            image_name if build_config.image_type == ImageType.container.value else None
-        ),
+        container_image=(image_name if build_config.image_type == ImageType.container.value else None),
         image_name=image_name,
         apis=apis,
         providers={},
@@ -228,13 +216,9 @@ def _generate_run_config(
             if p.deprecation_error:
                 raise InvalidProviderError(p.deprecation_error)
 
-            config_type = instantiate_class_type(
-                provider_registry[Api(api)][provider_type].config_class
-            )
+            config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
             if hasattr(config_type, "sample_run_config"):
-                config = config_type.sample_run_config(
-                    __distro_dir__=f"distributions/{image_name}"
-                )
+                config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
             else:
                 config = {}
 
@@ -269,9 +253,7 @@ def _run_stack_build_command_from_build_config(
             image_name = f"distribution-{template_name}"
         else:
             if not image_name:
-                raise ValueError(
-                    "Please specify an image name when building a container image without a template"
-                )
+                raise ValueError("Please specify an image name when building a container image without a template")
     elif build_config.image_type == ImageType.conda.value:
         if not image_name:
             raise ValueError("Please specify an image name when building a conda image")
@@ -299,10 +281,7 @@ def _run_stack_build_command_from_build_config(
 
     if template_name:
         # copy run.yaml from template to build_dir instead of generating it again
-        template_path = (
-            importlib.resources.files("llama_stack")
-            / f"templates/{template_name}/run.yaml"
-        )
+        template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
         with importlib.resources.as_file(template_path) as path:
             run_config_file = build_dir / f"{template_name}-run.yaml"
             shutil.copy(path, run_config_file)


@@ -82,31 +82,21 @@ class StackRun(Subcommand):
 
         if not config_file.exists() and not has_yaml_suffix:
             # check if this is a template
-            config_file = (
-                Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
-            )
+            config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
             if config_file.exists():
                 template_name = args.config
 
         if not config_file.exists() and not has_yaml_suffix:
             # check if it's a build config saved to conda dir
-            config_file = Path(
-                BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml"
-            )
+            config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")
 
         if not config_file.exists() and not has_yaml_suffix:
             # check if it's a build config saved to container dir
-            config_file = Path(
-                BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml"
-            )
+            config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")
 
         if not config_file.exists() and not has_yaml_suffix:
             # check if it's a build config saved to ~/.llama dir
-            config_file = Path(
-                DISTRIBS_BASE_DIR
-                / f"llamastack-{args.config}"
-                / f"{args.config}-run.yaml"
-            )
+            config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
 
         if not config_file.exists():
             self.parser.error(
@@ -119,15 +109,8 @@ class StackRun(Subcommand):
         config = parse_and_maybe_upgrade_config(config_dict)
 
         if config.container_image:
-            script = (
-                importlib.resources.files("llama_stack")
-                / "distribution/start_container.sh"
-            )
-            image_name = (
-                f"distribution-{template_name}"
-                if template_name
-                else config.container_image
-            )
+            script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
+            image_name = f"distribution-{template_name}" if template_name else config.container_image
             run_args = [script, image_name]
         else:
             current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
@@ -145,11 +128,7 @@ class StackRun(Subcommand):
                 if env_name == "base":
                     return os.environ.get("CONDA_PREFIX")
                 # Get conda environments info
-                conda_env_info = json.loads(
-                    subprocess.check_output(
-                        ["conda", "info", "--envs", "--json"]
-                    ).decode()
-                )
+                conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
                 envs = conda_env_info["envs"]
                 for envpath in envs:
                     if envpath.endswith(env_name):
@@ -173,10 +152,7 @@ class StackRun(Subcommand):
                 )
                 return
 
-            script = (
-                importlib.resources.files("llama_stack")
-                / "distribution/start_conda_env.sh"
-            )
+            script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
             run_args = [
                 script,
                 image_name,


@@ -22,11 +22,7 @@ def format_row(row, col_widths):
             if line.strip() == "":
                 lines.append("")
             else:
-                lines.extend(
-                    textwrap.wrap(
-                        line, width, break_long_words=False, replace_whitespace=False
-                    )
-                )
+                lines.extend(textwrap.wrap(line, width, break_long_words=False, replace_whitespace=False))
         return lines
 
     wrapped = [wrap(item, width) for item, width in zip(row, col_widths)]


@@ -41,9 +41,7 @@ def up_to_date_config():
         - provider_id: provider1
           provider_type: inline::meta-reference
          config: {{}}
-    """.format(
-        version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
-    )
+    """.format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
     )
 
 
@@ -83,9 +81,7 @@ def old_config():
           telemetry:
             provider_type: noop
             config: {{}}
-    """.format(
-        built_at=datetime.now().isoformat()
-    )
+    """.format(built_at=datetime.now().isoformat())
     )
 
 
@@ -108,10 +104,7 @@ def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
 def test_parse_and_maybe_upgrade_config_old_format(old_config):
     result = parse_and_maybe_upgrade_config(old_config)
     assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
-    assert all(
-        api in result.providers
-        for api in ["inference", "safety", "memory", "telemetry"]
-    )
+    assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
     safety_provider = result.providers["safety"][0]
     assert safety_provider.provider_type == "meta-reference"
     assert "llama_guard_shield" in safety_provider.config


@@ -72,9 +72,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
     return checksums
 
 
-def verify_files(
-    model_dir: Path, checksums: Dict[str, str], console: Console
-) -> List[VerificationResult]:
+def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
     results = []
 
     with Progress(