Changes from the main repo

This commit is contained in:
Ashwin Bharambe 2024-07-19 16:11:17 -07:00
parent 9c9b834c0f
commit 7d2c0b14b8
8 changed files with 24 additions and 9 deletions

View file

@@ -0,0 +1,9 @@
model_inference_config:
  impl_type: "inline"
  inline_config:
    checkpoint_type: "pytorch"
    checkpoint_dir: /home/cyni/local/llama-3
    tokenizer_path: /home/cyni/local/llama-3/cl_toplang_128k
    model_parallel_size: 1
    max_seq_len: 2048
    max_batch_size: 1

View file

@@ -1,10 +1,11 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from models.llama3_1.api.datatypes import URL
from pydantic import BaseModel from pydantic import BaseModel
from strong_typing.schema import json_schema_type from strong_typing.schema import json_schema_type
from models.llama3_1.api.datatypes import URL
@json_schema_type @json_schema_type

View file

@@ -4,7 +4,7 @@ from pydantic import BaseModel
from pyopenapi import webmethod from pyopenapi import webmethod
from models.llama3_1.api.datatypes import * # noqa: F403 from models.llama3_1.api.datatypes import * # noqa: F403
from .datatypes import * # noqa: F403 from .datatypes import * # noqa: F403
from toolchain.dataset.api.datatypes import * # noqa: F403 from toolchain.dataset.api.datatypes import * # noqa: F403
from toolchain.common.training_types import * # noqa: F403 from toolchain.common.training_types import * # noqa: F403

View file

@@ -93,7 +93,7 @@ class Llama:
checkpoints checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()] ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu") checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f: with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read()) params = json.loads(f.read())

View file

@@ -13,7 +13,8 @@ from toolchain.common.deployment_types import RestAPIExecutionConfig
@json_schema_type @json_schema_type
class BuiltinShield(Enum): class BuiltinShield(Enum):
llama_guard = "llama_guard" llama_guard = "llama_guard"
prompt_guard = "prompt_guard" injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
code_scanner_guard = "code_scanner_guard" code_scanner_guard = "code_scanner_guard"
third_party_shield = "third_party_shield" third_party_shield = "third_party_shield"
injection_shield = "injection_shield" injection_shield = "injection_shield"

View file

@@ -11,7 +11,11 @@ from .base import ( # noqa: F401
from .code_scanner import CodeScannerShield # noqa: F401 from .code_scanner import CodeScannerShield # noqa: F401
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401 from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401 from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield # noqa: F401 from .prompt_guard import ( # noqa: F401
InjectionShield,
JailbreakShield,
PromptGuardShield,
)
from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401 from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()

View file

@@ -1,4 +1,4 @@
#!/bin/bash #!/bin/bash
set -euo pipefail set -euo pipefail
@@ -10,10 +10,10 @@ rootdir=$(git rev-parse --show-toplevel)
files_to_copy=("toolchain/spec/openapi*" "models.llama3_1.api.datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py") files_to_copy=("toolchain/spec/openapi*" "models.llama3_1.api.datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py")
for file in "${files_to_copy[@]}"; do for file in "${files_to_copy[@]}"; do
relpath="$file" relpath="$file"
set -x set -x
mkdir -p "$TMPDIR/$(dirname $relpath)" mkdir -p "$TMPDIR/$(dirname $relpath)"
eval cp "$rootdir/$relpath" "$TMPDIR/$(dirname $relpath)" eval cp "$rootdir/$relpath" "$TMPDIR/$(dirname $relpath)"
set +x set +x
done done
cd "$TMPDIR" cd "$TMPDIR"

View file

@@ -1,5 +1,5 @@
#!/bin/bash #!/bin/bash
set -x set -x
PYTHONPATH=../../../oss-ops:../.. python3 -m toolchain.spec.generate PYTHONPATH=../../../oss-ops:../.. python3 -m toolchain.spec.generate