mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

commit 7d2c0b14b8 (parent 9c9b834c0f)
Changes from the main repo

8 changed files with 24 additions and 9 deletions
toolchain/configs/cyni.yaml (new file, +9)
@@ -0,0 +1,9 @@
+model_inference_config:
+  impl_type: "inline"
+  inline_config:
+    checkpoint_type: "pytorch"
+    checkpoint_dir: /home/cyni/local/llama-3
+    tokenizer_path: /home/cyni/local/llama-3/cl_toplang_128k
+    model_parallel_size: 1
+    max_seq_len: 2048
+    max_batch_size: 1
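The new cyni.yaml selects the inline (single-process) inference implementation and points it at a local Llama 3 checkpoint. As a sketch of how a config with this shape could be parsed and validated — assuming PyYAML and pydantic are available, and with the model class names below being illustrative rather than the toolchain's actual ones:

# Illustrative only: parse and validate a config shaped like cyni.yaml.
# The pydantic model names are assumptions, not toolchain classes.
import yaml
from pydantic import BaseModel


class InlineConfig(BaseModel):
    checkpoint_type: str
    checkpoint_dir: str
    tokenizer_path: str
    model_parallel_size: int
    max_seq_len: int
    max_batch_size: int


class ModelInferenceConfig(BaseModel):
    impl_type: str
    inline_config: InlineConfig


with open("toolchain/configs/cyni.yaml") as f:
    raw = yaml.safe_load(f)

config = ModelInferenceConfig(**raw["model_inference_config"])
print(config.inline_config.max_seq_len)  # 2048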
@@ -1,10 +1,11 @@
 from enum import Enum
 from typing import Any, Dict, Optional
 
+from models.llama3_1.api.datatypes import URL
+
 from pydantic import BaseModel
 
 from strong_typing.schema import json_schema_type
-from models.llama3_1.api.datatypes import URL
 
 
 @json_schema_type
@@ -93,7 +93,7 @@ class Llama:
             checkpoints
         ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
         ckpt_path = checkpoints[get_model_parallel_rank()]
-        checkpoint = torch.load(ckpt_path, map_location="cpu")
+        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
 
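The functional change here is the added weights_only=True. In PyTorch, this flag restricts torch.load to deserializing tensors and primitive containers rather than arbitrary pickled objects, so a tampered checkpoint file cannot execute code at load time. A self-contained sketch (the path and tensor shape are made up):

# Standalone illustration of the safer load path.
import torch

state = {"layers.0.weight": torch.randn(4, 4)}
torch.save(state, "/tmp/demo_ckpt.pth")

# weights_only=True refuses arbitrary object unpickling, which is what
# makes loading an untrusted checkpoint dangerous by default.
checkpoint = torch.load("/tmp/demo_ckpt.pth", map_location="cpu", weights_only=True)
print(checkpoint["layers.0.weight"].shape)  # torch.Size([4, 4])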
@@ -13,7 +13,8 @@ from toolchain.common.deployment_types import RestAPIExecutionConfig
 @json_schema_type
 class BuiltinShield(Enum):
     llama_guard = "llama_guard"
-    prompt_guard = "prompt_guard"
+    injection_shield = "injection_shield"
+    jailbreak_shield = "jailbreak_shield"
     code_scanner_guard = "code_scanner_guard"
     third_party_shield = "third_party_shield"
     injection_shield = "injection_shield"
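This hunk retires the umbrella prompt_guard value in favor of the two concrete shields it covered. A hypothetical migration sketch — the alias table and resolve helper are illustrative, not part of this commit:

# Hypothetical migration helper, not part of this commit: map the
# retired shield name onto the new BuiltinShield members.
from enum import Enum


class BuiltinShield(Enum):  # members as introduced by this change (illustrative)
    llama_guard = "llama_guard"
    injection_shield = "injection_shield"
    jailbreak_shield = "jailbreak_shield"
    code_scanner_guard = "code_scanner_guard"
    third_party_shield = "third_party_shield"


# "prompt_guard" no longer exists; callers pick the specific shield(s).
LEGACY_ALIASES = {
    "prompt_guard": [BuiltinShield.injection_shield, BuiltinShield.jailbreak_shield],
}


def resolve(name: str) -> list[BuiltinShield]:
    # Known legacy alias -> its replacements; otherwise a direct enum lookup,
    # which raises ValueError for unknown names.
    return LEGACY_ALIASES.get(name) or [BuiltinShield(name)]


print(resolve("prompt_guard"))  # both replacement shields
print(resolve("llama_guard"))   # [BuiltinShield.llama_guard]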
@@ -11,7 +11,11 @@ from .base import (  # noqa: F401
 from .code_scanner import CodeScannerShield  # noqa: F401
 from .contrib.third_party_shield import ThirdPartyShield  # noqa: F401
 from .llama_guard import LlamaGuardShield  # noqa: F401
-from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield  # noqa: F401
+from .prompt_guard import (  # noqa: F401
+    InjectionShield,
+    JailbreakShield,
+    PromptGuardShield,
+)
 from .shield_runner import SafetyException, ShieldRunnerMixin  # noqa: F401
 
 transformers.logging.set_verbosity_error()
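The prompt_guard import change is purely stylistic: the one-line import becomes the parenthesized, one-name-per-line form, so future additions or removals touch a single line of the file. A runnable stand-in that uses a stdlib module in place of the package-relative import:

# Semantically identical to the one-line form; each name sits on its own
# line, so adding or removing a symbol is a one-line diff. "# noqa: F401"
# tells flake8 the re-export-style unused imports are intentional.
from typing import (  # noqa: F401
    Any,
    Dict,
    Optional,
)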