Fix precommit check after moving to ruff (#927)
The lint check on the main branch is failing. This fixes it after the move to ruff in https://github.com/meta-llama/llama-stack/pull/921: we need to switch to a `ruff.toml` file, as well as fix and ignore some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
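For context, every hunk below collapses a call or definition that fits under 120 columns onto a single line, which is what a formatter line length of 120 would produce. A minimal `ruff.toml` sketch consistent with that is shown here; the line length is inferred from the diff, but the rule selection is an assumption, not the actual file added by this commit:

```toml
# Hypothetical ruff.toml sketch; only line-length is inferred from the diff below.
line-length = 120

[lint]
select = ["E", "F", "I"]  # assumed: pycodestyle errors, pyflakes, isort-style imports
ignore = ["E501"]         # assumed: leave long-line complaints to the formatter
```

With a config along these lines, `ruff format .` yields the kind of single-line rewrites seen in the hunks that follow.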
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
@@ -147,9 +147,7 @@ class ParallelDownloader:
             "follow_redirects": True,
         }
 
-    async def retry_with_exponential_backoff(
-        self, task: DownloadTask, func, *args, **kwargs
-    ):
+    async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
        last_exception = None
        for attempt in range(task.max_retries):
            try:
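The hunk above only shows the signature of `retry_with_exponential_backoff` being collapsed; its body is not part of the diff. For readers following along, a minimal standalone sketch of the usual pattern, assuming a doubling backoff schedule and guided only by the context lines visible in these hunks (`last_exception`, the `for attempt ...` loop, `raise last_exception`), might look like:

```python
import asyncio


async def retry_with_exponential_backoff(max_retries: int, func, *args, **kwargs):
    """Retry an async callable, doubling the wait after each failure.

    Hypothetical reconstruction; the real method additionally takes `self` and a
    DownloadTask and reads max_retries from the task, as the diff shows.
    """
    last_exception = None
    for attempt in range(max_retries):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            last_exception = e
            await asyncio.sleep(2**attempt)  # assumed schedule: 1s, 2s, 4s, ...
    raise last_exception
```

This is a sketch under stated assumptions, not the file's actual body.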
@@ -166,13 +164,9 @@ class ParallelDownloader:
                 continue
         raise last_exception
 
-    async def get_file_info(
-        self, client: httpx.AsyncClient, task: DownloadTask
-    ) -> None:
+    async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
         async def _get_info():
-            response = await client.head(
-                task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
-            )
+            response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
             response.raise_for_status()
             return response
 
@@ -201,14 +195,10 @@ class ParallelDownloader:
             return False
         return os.path.getsize(task.output_file) == task.total_size
 
-    async def download_chunk(
-        self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
-    ) -> None:
+    async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
         async def _download_chunk():
             headers = {"Range": f"bytes={start}-{end}"}
-            async with client.stream(
-                "GET", task.url, headers=headers, **self.client_options
-            ) as response:
+            async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
                 response.raise_for_status()
 
                 with open(task.output_file, "ab") as file:
@@ -225,8 +215,7 @@ class ParallelDownloader:
             await self.retry_with_exponential_backoff(task, _download_chunk)
         except Exception as e:
             raise DownloadError(
-                f"Failed to download chunk {start}-{end} after "
-                f"{task.max_retries} attempts: {str(e)}"
+                f"Failed to download chunk {start}-{end} after {task.max_retries} attempts: {str(e)}"
             ) from e
 
     async def prepare_download(self, task: DownloadTask) -> None:
@@ -244,9 +233,7 @@ class ParallelDownloader:
         # Check if file is already downloaded
         if os.path.exists(task.output_file):
             if self.verify_file_integrity(task):
-                self.console.print(
-                    f"[green]Already downloaded {task.output_file}[/green]"
-                )
+                self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
                 self.progress.update(task.task_id, completed=task.total_size)
                 return
 
@@ -259,9 +246,7 @@ class ParallelDownloader:
 
         current_pos = task.downloaded_size
         while current_pos < task.total_size:
-            chunk_end = min(
-                current_pos + chunk_size - 1, task.total_size - 1
-            )
+            chunk_end = min(current_pos + chunk_size - 1, task.total_size - 1)
             chunks.append((current_pos, chunk_end))
             current_pos = chunk_end + 1
 
@@ -273,18 +258,12 @@ class ParallelDownloader:
                 raise DownloadError(f"Download failed: {str(e)}") from e
 
         except Exception as e:
-            self.progress.update(
-                task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
-            )
-            raise DownloadError(
-                f"Download failed for {task.output_file}: {str(e)}"
-            ) from e
+            self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
+            raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
 
     def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
         try:
-            total_remaining_size = sum(
-                task.total_size - task.downloaded_size for task in tasks
-            )
+            total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
             dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
             free_space = shutil.disk_usage(dir_path).free
 
@@ -314,9 +293,7 @@ class ParallelDownloader:
         with self.progress:
             for task in tasks:
                 desc = f"Downloading {Path(task.output_file).name}"
-                task.task_id = self.progress.add_task(
-                    desc, total=task.total_size, completed=task.downloaded_size
-                )
+                task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
 
             semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
 
@@ -332,9 +309,7 @@ class ParallelDownloader:
         if failed_tasks:
             self.console.print("\n[red]Some downloads failed:[/red]")
             for task, error in failed_tasks:
-                self.console.print(
-                    f"[red]- {Path(task.output_file).name}: {error}[/red]"
-                )
+                self.console.print(f"[red]- {Path(task.output_file).name}: {error}[/red]")
             raise DownloadError(f"{len(failed_tasks)} downloads failed")
 
 
@@ -396,11 +371,7 @@ def _meta_download(
         output_file = str(output_dir / f)
         url = meta_url.replace("*", f"{info.folder}/{f}")
         total_size = info.pth_size if "consolidated" in f else 0
-        tasks.append(
-            DownloadTask(
-                url=url, output_file=output_file, total_size=total_size, max_retries=3
-            )
-        )
+        tasks.append(DownloadTask(url=url, output_file=output_file, total_size=total_size, max_retries=3))
 
     # Initialize and run parallel downloader
     downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
@@ -446,14 +417,10 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
     os.makedirs(output_dir, exist_ok=True)
 
     if any(output_dir.iterdir()):
-        console.print(
-            f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
-        )
+        console.print(f"[yellow]Output directory {output_dir} is not empty.[/yellow]")
 
         while True:
-            resp = input(
-                "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
-            )
+            resp = input("Do you want to (C)ontinue download or (R)estart completely? (continue/restart): ")
             if resp.lower() in ["restart", "r"]:
                 shutil.rmtree(output_dir)
                 os.makedirs(output_dir, exist_ok=True)
@@ -471,9 +438,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
     ]
 
     # Initialize and run parallel downloader
-    downloader = ParallelDownloader(
-        max_concurrent_downloads=max_concurrent_downloads
-    )
+    downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
     asyncio.run(downloader.download_all(tasks))
 
 
@@ -47,33 +47,20 @@ class ModelPromptFormat(Subcommand):
 
         # Only Llama 3.1 and 3.2 are supported
         supported_model_ids = [
-            m
-            for m in CoreModelId
-            if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
+            m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
         ]
         model_str = "\n".join([m.value for m in supported_model_ids])
         try:
             model_id = CoreModelId(args.model_name)
         except ValueError:
-            self.parser.error(
-                f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}"
-            )
+            self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
 
         if model_id not in supported_model_ids:
-            self.parser.error(
-                f"{model_id} is not a valid Model. Choose one from --\n {model_str}"
-            )
+            self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
 
-        llama_3_1_file = (
-            importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
-        )
-        llama_3_2_text_file = (
-            importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
-        )
-        llama_3_2_vision_file = (
-            importlib.resources.files("llama_models")
-            / "llama3_2/vision_prompt_format.md"
-        )
+        llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
+        llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
+        llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md"
         if model_family(model_id) == ModelFamily.llama3_1:
             with importlib.resources.as_file(llama_3_1_file) as f:
                 content = f.open("r").read()
@@ -17,16 +17,12 @@ class PromptGuardModel(BaseModel):
     """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
 
     model_id: str = "Prompt-Guard-86M"
-    description: str = (
-        "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
-    )
+    description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
     is_featured: bool = False
     huggingface_repo: str = "meta-llama/Prompt-Guard-86M"
     max_seq_length: int = 2048
     is_instruct_model: bool = False
-    quantization_format: CheckpointQuantizationFormat = (
-        CheckpointQuantizationFormat.bf16
-    )
+    quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
     arch_args: Dict[str, Any] = Field(default_factory=dict)
     recommended_sampling_params: Optional[SamplingParams] = None
 
@@ -56,9 +56,7 @@ def available_templates_specs() -> Dict[str, BuildConfig]:
     return template_specs
 
 
-def run_stack_build_command(
-    parser: argparse.ArgumentParser, args: argparse.Namespace
-) -> None:
+def run_stack_build_command(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
     if args.list_templates:
         return _run_template_list_cmd()
 
@@ -129,11 +127,7 @@ def run_stack_build_command(
 
         providers = dict()
         for api, providers_for_api in get_provider_registry().items():
-            available_providers = [
-                x
-                for x in providers_for_api.keys()
-                if x not in ("remote", "remote::sample")
-            ]
+            available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
             api_provider = prompt(
                 "> Enter provider for API {}: ".format(api.value),
                 completer=WordCompleter(available_providers),
@@ -156,9 +150,7 @@ def run_stack_build_command(
             description=description,
         )
 
-        build_config = BuildConfig(
-            image_type=image_type, distribution_spec=distribution_spec
-        )
+        build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
     else:
         with open(args.config, "r") as f:
             try:
@@ -179,9 +171,7 @@ def run_stack_build_command(
 
     if args.print_deps_only:
         print(f"# Dependencies for {args.template or args.config or image_name}")
-        normal_deps, special_deps = get_provider_dependencies(
-            build_config.distribution_spec.providers
-        )
+        normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
        normal_deps += SERVER_DEPENDENCIES
        print(f"uv pip install {' '.join(normal_deps)}")
        for special_dep in special_deps:
@@ -206,9 +196,7 @@ def _generate_run_config(
     """
     apis = list(build_config.distribution_spec.providers.keys())
     run_config = StackRunConfig(
-        container_image=(
-            image_name if build_config.image_type == ImageType.container.value else None
-        ),
+        container_image=(image_name if build_config.image_type == ImageType.container.value else None),
         image_name=image_name,
         apis=apis,
         providers={},
@@ -228,13 +216,9 @@ def _generate_run_config(
         if p.deprecation_error:
             raise InvalidProviderError(p.deprecation_error)
 
-        config_type = instantiate_class_type(
-            provider_registry[Api(api)][provider_type].config_class
-        )
+        config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
         if hasattr(config_type, "sample_run_config"):
-            config = config_type.sample_run_config(
-                __distro_dir__=f"distributions/{image_name}"
-            )
+            config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
         else:
             config = {}
 
@@ -269,9 +253,7 @@ def _run_stack_build_command_from_build_config(
             image_name = f"distribution-{template_name}"
         else:
             if not image_name:
-                raise ValueError(
-                    "Please specify an image name when building a container image without a template"
-                )
+                raise ValueError("Please specify an image name when building a container image without a template")
     elif build_config.image_type == ImageType.conda.value:
         if not image_name:
             raise ValueError("Please specify an image name when building a conda image")
@@ -299,10 +281,7 @@ def _run_stack_build_command_from_build_config(
 
     if template_name:
         # copy run.yaml from template to build_dir instead of generating it again
-        template_path = (
-            importlib.resources.files("llama_stack")
-            / f"templates/{template_name}/run.yaml"
-        )
+        template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
         with importlib.resources.as_file(template_path) as path:
             run_config_file = build_dir / f"{template_name}-run.yaml"
             shutil.copy(path, run_config_file)
@@ -82,31 +82,21 @@ class StackRun(Subcommand):
 
         if not config_file.exists() and not has_yaml_suffix:
             # check if this is a template
-            config_file = (
-                Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
-            )
+            config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
             if config_file.exists():
                 template_name = args.config
 
         if not config_file.exists() and not has_yaml_suffix:
             # check if it's a build config saved to conda dir
-            config_file = Path(
-                BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml"
-            )
+            config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")
 
         if not config_file.exists() and not has_yaml_suffix:
             # check if it's a build config saved to container dir
-            config_file = Path(
-                BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml"
-            )
+            config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")
 
         if not config_file.exists() and not has_yaml_suffix:
             # check if it's a build config saved to ~/.llama dir
-            config_file = Path(
-                DISTRIBS_BASE_DIR
-                / f"llamastack-{args.config}"
-                / f"{args.config}-run.yaml"
-            )
+            config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
 
         if not config_file.exists():
             self.parser.error(
@@ -119,15 +109,8 @@ class StackRun(Subcommand):
         config = parse_and_maybe_upgrade_config(config_dict)
 
         if config.container_image:
-            script = (
-                importlib.resources.files("llama_stack")
-                / "distribution/start_container.sh"
-            )
-            image_name = (
-                f"distribution-{template_name}"
-                if template_name
-                else config.container_image
-            )
+            script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
+            image_name = f"distribution-{template_name}" if template_name else config.container_image
             run_args = [script, image_name]
         else:
             current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
@@ -145,11 +128,7 @@ class StackRun(Subcommand):
             if env_name == "base":
                 return os.environ.get("CONDA_PREFIX")
             # Get conda environments info
-            conda_env_info = json.loads(
-                subprocess.check_output(
-                    ["conda", "info", "--envs", "--json"]
-                ).decode()
-            )
+            conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
             envs = conda_env_info["envs"]
             for envpath in envs:
                 if envpath.endswith(env_name):
@@ -173,10 +152,7 @@ class StackRun(Subcommand):
             )
             return
 
-            script = (
-                importlib.resources.files("llama_stack")
-                / "distribution/start_conda_env.sh"
-            )
+            script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh"
             run_args = [
                 script,
                 image_name,
@@ -22,11 +22,7 @@ def format_row(row, col_widths):
             if line.strip() == "":
                 lines.append("")
             else:
-                lines.extend(
-                    textwrap.wrap(
-                        line, width, break_long_words=False, replace_whitespace=False
-                    )
-                )
+                lines.extend(textwrap.wrap(line, width, break_long_words=False, replace_whitespace=False))
         return lines
 
     wrapped = [wrap(item, width) for item, width in zip(row, col_widths)]
@@ -41,9 +41,7 @@ def up_to_date_config():
         - provider_id: provider1
           provider_type: inline::meta-reference
           config: {{}}
-    """.format(
-            version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
-        )
+    """.format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
     )
 
 
@@ -83,9 +81,7 @@ def old_config():
         telemetry:
           provider_type: noop
           config: {{}}
-    """.format(
-            built_at=datetime.now().isoformat()
-        )
+    """.format(built_at=datetime.now().isoformat())
     )
 
 
@@ -108,10 +104,7 @@ def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
 def test_parse_and_maybe_upgrade_config_old_format(old_config):
     result = parse_and_maybe_upgrade_config(old_config)
     assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
-    assert all(
-        api in result.providers
-        for api in ["inference", "safety", "memory", "telemetry"]
-    )
+    assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
     safety_provider = result.providers["safety"][0]
     assert safety_provider.provider_type == "meta-reference"
     assert "llama_guard_shield" in safety_provider.config
@@ -72,9 +72,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
     return checksums
 
 
-def verify_files(
-    model_dir: Path, checksums: Dict[str, str], console: Console
-) -> List[VerificationResult]:
+def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
     results = []
 
     with Progress(