Add special case for prompt guard

This commit is contained in:
Ashwin Bharambe 2024-10-02 08:38:23 -07:00
parent a80b707ff8
commit cc5029a716
4 changed files with 76 additions and 13 deletions

View file

@ -38,9 +38,6 @@ class Download(Subcommand):
def setup_download_parser(parser: argparse.ArgumentParser) -> None:
from llama_models.sku_list import all_registered_models
models = all_registered_models()
parser.add_argument(
"--source",
choices=["meta", "huggingface"],
@ -123,16 +120,12 @@ def _hf_download(
print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(model: "Model", meta_url: str):
from llama_models.sku_list import llama_meta_net_info
def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
from llama_stack.distribution.utils.model_utils import model_local_dir
output_dir = Path(model_local_dir(model.descriptor()))
os.makedirs(output_dir, exist_ok=True)
info = llama_meta_net_info(model)
# I believe we can use some concurrency here if needed but not sure it is worth it
for f in info.files:
output_file = str(output_dir / f)
@ -147,7 +140,9 @@ def _meta_download(model: "Model", meta_url: str):
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_models.sku_list import resolve_model
from llama_models.sku_list import llama_meta_net_info, resolve_model
from .model.safety_models import prompt_guard_download_info, prompt_guard_model_sku
if args.manifest_file:
_download_from_manifest(args.manifest_file)
@ -157,7 +152,14 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
parser.error("Please provide a model id")
return
model = resolve_model(args.model_id)
prompt_guard = prompt_guard_model_sku()
if args.model_id == prompt_guard.model_id:
model = prompt_guard
info = prompt_guard_download_info()
else:
model = resolve_model(args.model_id)
info = llama_meta_net_info(model)
if model is None:
parser.error(f"Model {args.model_id} not found")
return
@ -171,7 +173,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
)
assert meta_url is not None and "llamameta.net" in meta_url
_meta_download(model, meta_url)
_meta_download(model, meta_url, info)
class ModelEntry(BaseModel):