mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Changes from the main repo
This commit is contained in:
parent
9c9b834c0f
commit
7d2c0b14b8
8 changed files with 24 additions and 9 deletions
9
toolchain/configs/cyni.yaml
Normal file
9
toolchain/configs/cyni.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
model_inference_config:
|
||||||
|
impl_type: "inline"
|
||||||
|
inline_config:
|
||||||
|
checkpoint_type: "pytorch"
|
||||||
|
checkpoint_dir: /home/cyni/local/llama-3
|
||||||
|
tokenizer_path: /home/cyni/local/llama-3/cl_toplang_128k
|
||||||
|
model_parallel_size: 1
|
||||||
|
max_seq_len: 2048
|
||||||
|
max_batch_size: 1
|
|
@ -1,10 +1,11 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from models.llama3_1.api.datatypes import URL
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from strong_typing.schema import json_schema_type
|
from strong_typing.schema import json_schema_type
|
||||||
from models.llama3_1.api.datatypes import URL
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -4,7 +4,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from pyopenapi import webmethod
|
from pyopenapi import webmethod
|
||||||
|
|
||||||
from models.llama3_1.api.datatypes import * # noqa: F403
|
from models.llama3_1.api.datatypes import * # noqa: F403
|
||||||
from .datatypes import * # noqa: F403
|
from .datatypes import * # noqa: F403
|
||||||
from toolchain.dataset.api.datatypes import * # noqa: F403
|
from toolchain.dataset.api.datatypes import * # noqa: F403
|
||||||
from toolchain.common.training_types import * # noqa: F403
|
from toolchain.common.training_types import * # noqa: F403
|
||||||
|
|
|
@ -93,7 +93,7 @@ class Llama:
|
||||||
checkpoints
|
checkpoints
|
||||||
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||||
params = json.loads(f.read())
|
params = json.loads(f.read())
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,8 @@ from toolchain.common.deployment_types import RestAPIExecutionConfig
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BuiltinShield(Enum):
|
class BuiltinShield(Enum):
|
||||||
llama_guard = "llama_guard"
|
llama_guard = "llama_guard"
|
||||||
prompt_guard = "prompt_guard"
|
injection_shield = "injection_shield"
|
||||||
|
jailbreak_shield = "jailbreak_shield"
|
||||||
code_scanner_guard = "code_scanner_guard"
|
code_scanner_guard = "code_scanner_guard"
|
||||||
third_party_shield = "third_party_shield"
|
third_party_shield = "third_party_shield"
|
||||||
injection_shield = "injection_shield"
|
injection_shield = "injection_shield"
|
||||||
|
|
|
@ -11,7 +11,11 @@ from .base import ( # noqa: F401
|
||||||
from .code_scanner import CodeScannerShield # noqa: F401
|
from .code_scanner import CodeScannerShield # noqa: F401
|
||||||
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
|
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
|
||||||
from .llama_guard import LlamaGuardShield # noqa: F401
|
from .llama_guard import LlamaGuardShield # noqa: F401
|
||||||
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield # noqa: F401
|
from .prompt_guard import ( # noqa: F401
|
||||||
|
InjectionShield,
|
||||||
|
JailbreakShield,
|
||||||
|
PromptGuardShield,
|
||||||
|
)
|
||||||
from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401
|
from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
|
@ -10,10 +10,10 @@ rootdir=$(git rev-parse --show-toplevel)
|
||||||
files_to_copy=("toolchain/spec/openapi*" "models.llama3_1.api.datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py")
|
files_to_copy=("toolchain/spec/openapi*" "models.llama3_1.api.datatypes.py" "toolchain/inference/api/*.py" "agentic_system/api/*.py" "toolchain/common/*.py" "toolchain/dataset/api/*.py" "toolchain/evaluations/api/*.py" "toolchain/reward_scoring/api/*.py" "toolchain/post_training/api/*.py" "toolchain/safety/api/*.py")
|
||||||
for file in "${files_to_copy[@]}"; do
|
for file in "${files_to_copy[@]}"; do
|
||||||
relpath="$file"
|
relpath="$file"
|
||||||
set -x
|
set -x
|
||||||
mkdir -p "$TMPDIR/$(dirname $relpath)"
|
mkdir -p "$TMPDIR/$(dirname $relpath)"
|
||||||
eval cp "$rootdir/$relpath" "$TMPDIR/$(dirname $relpath)"
|
eval cp "$rootdir/$relpath" "$TMPDIR/$(dirname $relpath)"
|
||||||
set +x
|
set +x
|
||||||
done
|
done
|
||||||
|
|
||||||
cd "$TMPDIR"
|
cd "$TMPDIR"
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
set -x
|
set -x
|
||||||
|
|
||||||
PYTHONPATH=../../../oss-ops:../.. python3 -m toolchain.spec.generate
|
PYTHONPATH=../../../oss-ops:../.. python3 -m toolchain.spec.generate
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue