remove unecessary code

Signed-off-by: reidliu <reid201711@gmail.com>
This commit is contained in:
reidliu 2025-02-20 17:43:34 +08:00
parent 8a0917a01b
commit 30f97a0de0

View file

@ -7,29 +7,10 @@
import argparse
import os
import shutil
import sys
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
def _ask_for_confirm(model) -> bool:
input_text = input(f"Are you sure you want to remove {model}? (y/n): ").strip().lower()
if input_text == "y":
return True
elif input_text == "n":
return False
return False
def _remove_model(model) -> None:
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, model)
if os.path.exists(model_path):
shutil.rmtree(model_path)
print(f"{model} removed.")
else:
print(f"{model} does not exist.")
sys.exit(1)
from llama_stack.models.llama.sku_list import all_registered_models
class ModelRemove(Subcommand):
@ -51,7 +32,7 @@ class ModelRemove(Subcommand):
"-m",
"--model",
required=True,
help="Specify the llama downloaded model name",
help="Specify the llama downloaded model name, see `llama model list --downloaded`",
)
self.parser.add_argument(
"-f",
@ -61,11 +42,24 @@ class ModelRemove(Subcommand):
)
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model)
model_list = []
for model in all_registered_models() + [prompt_guard_model_sku()]:
model_list.append(model.descriptor().replace(":", "-"))
if args.model not in model_list or os.path.isdir(model_path):
print(f"'{args.model}' is not a valid llama model or does not exist.")
return
if args.force:
_remove_model(args.model)
shutil.rmtree(model_path)
print(f"{args.model} removed.")
else:
confirm = _ask_for_confirm(args.model)
if confirm:
_remove_model(args.model)
if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
shutil.rmtree(model_path)
print(f"{args.model} removed.")
else:
print("Removal aborted.")