mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
Changes from the main repo
This commit is contained in:
parent
9c9b834c0f
commit
7d2c0b14b8
8 changed files with 24 additions and 9 deletions
9
toolchain/configs/cyni.yaml
Normal file
9
toolchain/configs/cyni.yaml
Normal file
|
@ -0,0 +1,9 @@
|
|||
model_inference_config:
|
||||
impl_type: "inline"
|
||||
inline_config:
|
||||
checkpoint_type: "pytorch"
|
||||
checkpoint_dir: /home/cyni/local/llama-3
|
||||
tokenizer_path: /home/cyni/local/llama-3/cl_toplang_128k
|
||||
model_parallel_size: 1
|
||||
max_seq_len: 2048
|
||||
max_batch_size: 1
|
|
@ -1,10 +1,11 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from models.llama3_1.api.datatypes import URL
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from strong_typing.schema import json_schema_type
|
||||
from models.llama3_1.api.datatypes import URL
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|||
|
||||
from pyopenapi import webmethod
|
||||
|
||||
from models.llama3_1.api.datatypes import * # noqa: F403
|
||||
from models.llama3_1.api.datatypes import * # noqa: F403
|
||||
from .datatypes import * # noqa: F403
|
||||
from toolchain.dataset.api.datatypes import * # noqa: F403
|
||||
from toolchain.common.training_types import * # noqa: F403
|
||||
|
|
|
@ -93,7 +93,7 @@ class Llama:
|
|||
checkpoints
|
||||
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
|
|
|
@ -13,7 +13,8 @@ from toolchain.common.deployment_types import RestAPIExecutionConfig
|
|||
@json_schema_type
|
||||
class BuiltinShield(Enum):
|
||||
llama_guard = "llama_guard"
|
||||
prompt_guard = "prompt_guard"
|
||||
injection_shield = "injection_shield"
|
||||
jailbreak_shield = "jailbreak_shield"
|
||||
code_scanner_guard = "code_scanner_guard"
|
||||
third_party_shield = "third_party_shield"
|
||||
injection_shield = "injection_shield"
|
||||
|
|
|
@ -11,7 +11,11 @@ from .base import ( # noqa: F401
|
|||
from .code_scanner import CodeScannerShield # noqa: F401
|
||||
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
|
||||
from .llama_guard import LlamaGuardShield # noqa: F401
|
||||
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield # noqa: F401
|
||||
from .prompt_guard import ( # noqa: F401
|
||||
InjectionShield,
|
||||
JailbreakShield,
|
||||
PromptGuardShield,
|
||||
)
|
||||
from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#!/bin/bash
|
||||
#!/bin/bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
|
@ -10,10 +10,10 @@ rootdir=$(git rev-parse --show-toplevel)
|
|||
files_to_copy=("toolchain/spec/openapi*" "models.llama3_1.api.datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py")
|
||||
for file in "${files_to_copy[@]}"; do
|
||||
relpath="$file"
|
||||
set -x
|
||||
set -x
|
||||
mkdir -p "$TMPDIR/$(dirname $relpath)"
|
||||
eval cp "$rootdir/$relpath" "$TMPDIR/$(dirname $relpath)"
|
||||
set +x
|
||||
set +x
|
||||
done
|
||||
|
||||
cd "$TMPDIR"
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -x
|
||||
set -x
|
||||
|
||||
PYTHONPATH=../../../oss-ops:../.. python3 -m toolchain.spec.generate
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue