Add special case for prompt guard
parent a80b707ff8
commit cc5029a716

4 changed files with 76 additions and 13 deletions
```diff
@@ -38,9 +38,6 @@ class Download(Subcommand):
 
 
 def setup_download_parser(parser: argparse.ArgumentParser) -> None:
-    from llama_models.sku_list import all_registered_models
-
-    models = all_registered_models()
     parser.add_argument(
         "--source",
         choices=["meta", "huggingface"],
@@ -123,16 +120,12 @@ def _hf_download(
     print(f"\nSuccessfully downloaded model to {true_output_dir}")
 
 
-def _meta_download(model: "Model", meta_url: str):
-    from llama_models.sku_list import llama_meta_net_info
-
+def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
     from llama_stack.distribution.utils.model_utils import model_local_dir
 
     output_dir = Path(model_local_dir(model.descriptor()))
     os.makedirs(output_dir, exist_ok=True)
 
-    info = llama_meta_net_info(model)
-
     # I believe we can use some concurrency here if needed but not sure it is worth it
     for f in info.files:
         output_file = str(output_dir / f)
```
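The refactored `_meta_download` now receives its `LlamaDownloadInfo` from the caller instead of resolving it via `llama_meta_net_info`, which is what lets the prompt-guard path supply its own download metadata. A minimal sketch of the shape this function depends on, assuming only what the diff actually uses (`info.files` and `model.descriptor()`; all other names here are illustrative):

```python
# Minimal sketch, not the library's actual definitions: the diff only
# shows that `info` exposes a `files` list and that `model` exposes
# `descriptor()`. Class and helper names here are assumptions.
from pathlib import Path

from pydantic import BaseModel


class LlamaDownloadInfo(BaseModel):
    files: list[str]  # relative file names to fetch via the signed URL


def resolve_output_paths(
    model_descriptor: str, info: LlamaDownloadInfo, checkpoint_root: Path
) -> list[Path]:
    # Mirrors the loop in _meta_download: one local path per listed file,
    # rooted at the model's local directory.
    output_dir = checkpoint_root / model_descriptor
    output_dir.mkdir(parents=True, exist_ok=True)
    return [output_dir / f for f in info.files]
```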
```diff
@@ -147,7 +140,9 @@ def _meta_download(model: "Model", meta_url: str):
 
 
 def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
-    from llama_models.sku_list import resolve_model
+    from llama_models.sku_list import llama_meta_net_info, resolve_model
+
+    from .model.safety_models import prompt_guard_download_info, prompt_guard_model_sku
 
     if args.manifest_file:
         _download_from_manifest(args.manifest_file)
@@ -157,7 +152,14 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
         parser.error("Please provide a model id")
         return
 
-    model = resolve_model(args.model_id)
+    prompt_guard = prompt_guard_model_sku()
+    if args.model_id == prompt_guard.model_id:
+        model = prompt_guard
+        info = prompt_guard_download_info()
+    else:
+        model = resolve_model(args.model_id)
+        info = llama_meta_net_info(model)
+
     if model is None:
         parser.error(f"Model {args.model_id} not found")
         return
```
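`prompt_guard_model_sku()` and `prompt_guard_download_info()` come from `.model.safety_models`, one of the changed files not shown in this excerpt. A hedged sketch of the interfaces they would need to satisfy the call sites above; the bodies, the id string, and the file name are placeholders, not values from the commit:

```python
# Hypothetical sketch of the safety_models helpers. Only the signatures
# and the attributes used by run_download_cmd/_meta_download are implied
# by the diff; the id string and file list below are placeholders.
from pydantic import BaseModel


class PromptGuardSku(BaseModel):
    model_id: str = "Prompt-Guard-86M"  # placeholder id, not from the diff

    def descriptor(self) -> str:
        # _meta_download calls model.descriptor() to pick the output dir.
        return self.model_id


def prompt_guard_model_sku() -> PromptGuardSku:
    return PromptGuardSku()


def prompt_guard_download_info():
    # Must expose a `files` list, matching what _meta_download iterates.
    class _Info(BaseModel):
        files: list[str] = ["model.safetensors"]  # placeholder file name

    return _Info()
```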
```diff
@@ -171,7 +173,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
             "Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
         )
         assert meta_url is not None and "llamameta.net" in meta_url
-        _meta_download(model, meta_url)
+        _meta_download(model, meta_url, info)
 
 
 class ModelEntry(BaseModel):
```
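End to end, the special case keeps the user-facing flow unchanged; only the source of the download metadata differs. A hedged usage sketch (flag names inferred from the `args.source` / `args.model_id` attributes visible in the diff; the model id string is a placeholder):

```
llama download --source meta --model-id Prompt-Guard-86M
```

If the model id matches the prompt-guard SKU, the command skips `llama_meta_net_info` and downloads the files listed by `prompt_guard_download_info()` instead, still prompting for the signed llamameta.net URL.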