Changes from the main repo

This commit is contained in:
Ashwin Bharambe 2024-07-19 16:11:17 -07:00
parent 9c9b834c0f
commit 7d2c0b14b8
8 changed files with 24 additions and 9 deletions

View file

@@ -0,0 +1,9 @@
model_inference_config:
impl_type: "inline"
inline_config:
checkpoint_type: "pytorch"
checkpoint_dir: /home/cyni/local/llama-3
tokenizer_path: /home/cyni/local/llama-3/cl_toplang_128k
model_parallel_size: 1
max_seq_len: 2048
max_batch_size: 1

View file

@@ -1,10 +1,11 @@
from enum import Enum
from typing import Any, Dict, Optional
from models.llama3_1.api.datatypes import URL
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
from models.llama3_1.api.datatypes import URL
@json_schema_type

View file

@@ -4,7 +4,7 @@ from pydantic import BaseModel
from pyopenapi import webmethod
from models.llama3_1.api.datatypes import * # noqa: F403
from models.llama3_1.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403
from toolchain.dataset.api.datatypes import * # noqa: F403
from toolchain.common.training_types import * # noqa: F403

View file

@@ -93,7 +93,7 @@ class Llama:
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu")
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())

View file

@@ -13,7 +13,8 @@ from toolchain.common.deployment_types import RestAPIExecutionConfig
@json_schema_type
class BuiltinShield(Enum):
llama_guard = "llama_guard"
prompt_guard = "prompt_guard"
injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
code_scanner_guard = "code_scanner_guard"
third_party_shield = "third_party_shield"
injection_shield = "injection_shield"

View file

@@ -11,7 +11,11 @@ from .base import ( # noqa: F401
from .code_scanner import CodeScannerShield # noqa: F401
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield # noqa: F401
from .prompt_guard import ( # noqa: F401
InjectionShield,
JailbreakShield,
PromptGuardShield,
)
from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401
transformers.logging.set_verbosity_error()