mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00

Changes from the main repo

parent 9c9b834c0f
commit 7d2c0b14b8

8 changed files with 24 additions and 9 deletions
toolchain/configs/cyni.yaml  (new file, 9 additions)
@@ -0,0 +1,9 @@
+model_inference_config:
+  impl_type: "inline"
+  inline_config:
+    checkpoint_type: "pytorch"
+    checkpoint_dir: /home/cyni/local/llama-3
+    tokenizer_path: /home/cyni/local/llama-3/cl_toplang_128k
+    model_parallel_size: 1
+    max_seq_len: 2048
+    max_batch_size: 1
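As an aside, a config in this shape is typically deserialized into typed models before use. The sketch below is hypothetical: the ModelInferenceConfig/InlineConfig classes, their fields, and the loader are assumptions inferred from the YAML keys above, not the toolchain's actual API.

# Hypothetical loader sketch; class names and fields are inferred from
# the YAML keys above, not taken from the real toolchain code.
import yaml
from pydantic import BaseModel


class InlineConfig(BaseModel):
    checkpoint_type: str
    checkpoint_dir: str
    tokenizer_path: str
    model_parallel_size: int
    max_seq_len: int
    max_batch_size: int


class ModelInferenceConfig(BaseModel):
    impl_type: str
    inline_config: InlineConfig


with open("toolchain/configs/cyni.yaml") as f:
    raw = yaml.safe_load(f)

# Validates types and rejects missing keys at startup rather than at inference time.
config = ModelInferenceConfig(**raw["model_inference_config"])
print(config.inline_config.max_seq_len)  # 2048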
@@ -1,10 +1,11 @@
 from enum import Enum
 from typing import Any, Dict, Optional
 
+from models.llama3_1.api.datatypes import URL
+
 from pydantic import BaseModel
 
 from strong_typing.schema import json_schema_type
-from models.llama3_1.api.datatypes import URL
 
 
 @json_schema_type
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 
 from pyopenapi import webmethod
 
-from models.llama3_1.api.datatypes import *  # noqa: F403
+from models.llama3_1.api.datatypes import *  # noqa: F403
 from .datatypes import *  # noqa: F403
 from toolchain.dataset.api.datatypes import *  # noqa: F403
 from toolchain.common.training_types import *  # noqa: F403
@@ -93,7 +93,7 @@ class Llama:
             checkpoints
         ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
         ckpt_path = checkpoints[get_model_parallel_rank()]
-        checkpoint = torch.load(ckpt_path, map_location="cpu")
+        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
 
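For context, weights_only=True switches torch.load to a restricted unpickler that only reconstructs tensors and plain containers, so loading an untrusted checkpoint cannot execute arbitrary code embedded in the pickle. A standalone sketch, not the repo's code (assumes a PyTorch version with weights_only support, available since 1.13; the /tmp path is illustrative):

# Standalone illustration of the weights_only flag.
import torch

state = {"layers.0.weight": torch.randn(4, 4)}
torch.save(state, "/tmp/ckpt.pth")

# The restricted unpickler rebuilds tensors and primitive containers only;
# pickles carrying arbitrary Python objects raise instead of executing.
checkpoint = torch.load("/tmp/ckpt.pth", map_location="cpu", weights_only=True)
print(checkpoint["layers.0.weight"].shape)  # torch.Size([4, 4])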
@@ -13,7 +13,8 @@ from toolchain.common.deployment_types import RestAPIExecutionConfig
 @json_schema_type
 class BuiltinShield(Enum):
     llama_guard = "llama_guard"
     prompt_guard = "prompt_guard"
+    injection_shield = "injection_shield"
+    jailbreak_shield = "jailbreak_shield"
     code_scanner_guard = "code_scanner_guard"
     third_party_shield = "third_party_shield"
-    injection_shield = "injection_shield"
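The enum is a string-valued registry of shield identifiers, which lets config strings round-trip to enum members. A small hypothetical sketch (the enum body mirrors the diff above; resolve_shield is illustrative, not repo code):

# Hypothetical sketch; resolve_shield is illustrative only.
from enum import Enum


class BuiltinShield(Enum):
    llama_guard = "llama_guard"
    prompt_guard = "prompt_guard"
    injection_shield = "injection_shield"
    jailbreak_shield = "jailbreak_shield"
    code_scanner_guard = "code_scanner_guard"
    third_party_shield = "third_party_shield"


def resolve_shield(name: str) -> BuiltinShield:
    # Lookup by value; raises ValueError for unknown shield names.
    return BuiltinShield(name)


assert resolve_shield("jailbreak_shield") is BuiltinShield.jailbreak_shield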
@@ -11,7 +11,11 @@ from .base import (  # noqa: F401
 from .code_scanner import CodeScannerShield  # noqa: F401
 from .contrib.third_party_shield import ThirdPartyShield  # noqa: F401
 from .llama_guard import LlamaGuardShield  # noqa: F401
-from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield  # noqa: F401
+from .prompt_guard import (  # noqa: F401
+    InjectionShield,
+    JailbreakShield,
+    PromptGuardShield,
+)
 from .shield_runner import SafetyException, ShieldRunnerMixin  # noqa: F401
 
 transformers.logging.set_verbosity_error()
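The # noqa: F401 markers flag these as deliberate re-exports: the package __init__ becomes the public import surface for the shield classes. A hypothetical consumer sketch (the toolchain.safety.shields package path is an assumption inferred from the other toolchain.* paths in this diff, not confirmed by it):

# Hypothetical consumer; the package path is assumed, not shown in the diff.
from toolchain.safety.shields import JailbreakShield, SafetyException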