mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-13 05:17:26 +00:00
remove unecessary code
Signed-off-by: reidliu <reid201711@gmail.com>
This commit is contained in:
parent
8a0917a01b
commit
30f97a0de0
1 changed files with 19 additions and 25 deletions
|
@ -7,29 +7,10 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||||
|
from llama_stack.models.llama.sku_list import all_registered_models
|
||||||
|
|
||||||
def _ask_for_confirm(model) -> bool:
|
|
||||||
input_text = input(f"Are you sure you want to remove {model}? (y/n): ").strip().lower()
|
|
||||||
if input_text == "y":
|
|
||||||
return True
|
|
||||||
elif input_text == "n":
|
|
||||||
return False
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _remove_model(model) -> None:
|
|
||||||
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, model)
|
|
||||||
if os.path.exists(model_path):
|
|
||||||
shutil.rmtree(model_path)
|
|
||||||
print(f"{model} removed.")
|
|
||||||
else:
|
|
||||||
print(f"{model} does not exist.")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRemove(Subcommand):
|
class ModelRemove(Subcommand):
|
||||||
|
@ -51,7 +32,7 @@ class ModelRemove(Subcommand):
|
||||||
"-m",
|
"-m",
|
||||||
"--model",
|
"--model",
|
||||||
required=True,
|
required=True,
|
||||||
help="Specify the llama downloaded model name",
|
help="Specify the llama downloaded model name, see `llama model list --downloaded`",
|
||||||
)
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"-f",
|
"-f",
|
||||||
|
@ -61,11 +42,24 @@ class ModelRemove(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
|
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
from .safety_models import prompt_guard_model_sku
|
||||||
|
|
||||||
|
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model)
|
||||||
|
|
||||||
|
model_list = []
|
||||||
|
for model in all_registered_models() + [prompt_guard_model_sku()]:
|
||||||
|
model_list.append(model.descriptor().replace(":", "-"))
|
||||||
|
|
||||||
|
if args.model not in model_list or os.path.isdir(model_path):
|
||||||
|
print(f"'{args.model}' is not a valid llama model or does not exist.")
|
||||||
|
return
|
||||||
|
|
||||||
if args.force:
|
if args.force:
|
||||||
_remove_model(args.model)
|
shutil.rmtree(model_path)
|
||||||
|
print(f"{args.model} removed.")
|
||||||
else:
|
else:
|
||||||
confirm = _ask_for_confirm(args.model)
|
if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
|
||||||
if confirm:
|
shutil.rmtree(model_path)
|
||||||
_remove_model(args.model)
|
print(f"{args.model} removed.")
|
||||||
else:
|
else:
|
||||||
print("Removal aborted.")
|
print("Removal aborted.")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue