forked from phoenix-oss/llama-stack-mirror
Revert parts of 0d2eb3bd25
This commit is contained in:
parent 059e50b389
commit a2465f3f9c
4 changed files with 75 additions and 43 deletions

llama_stack/providers/impls/meta_reference/safety/__init__.py

```diff
@@ -7,11 +7,11 @@
 from .config import SafetyConfig


-async def get_provider_impl(config: SafetyConfig, deps):
+async def get_provider_impl(config: SafetyConfig, _deps):
     from .safety import MetaReferenceSafetyImpl

     assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"

-    impl = MetaReferenceSafetyImpl(config, deps)
+    impl = MetaReferenceSafetyImpl(config)
     await impl.initialize()
     return impl
```
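
With `deps` gone from the factory signature, nothing from `Api.inference` has to be wired in before the safety provider can come up. A minimal sketch of calling the reverted factory, assuming `SafetyConfig` is a pydantic-style model whose two shield sections are optional and default to unset:

```python
import asyncio

from llama_stack.providers.impls.meta_reference.safety import (
    SafetyConfig,
    get_provider_impl,
)


async def main() -> None:
    # Hypothetical config: leave both optional shield sections unset so that
    # initialize() skips loading any Llama Guard / Prompt Guard checkpoints.
    config = SafetyConfig()

    # The second argument is ignored after this commit (renamed to `_deps`),
    # so an empty dict is enough; no inference provider has to be registered.
    impl = await get_provider_impl(config, {})
    print(type(impl).__name__)  # MetaReferenceSafetyImpl


if __name__ == "__main__":
    asyncio.run(main())
```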

llama_stack/providers/impls/meta_reference/safety/safety.py

```diff
@@ -7,10 +7,8 @@
 from llama_models.sku_list import resolve_model

 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import Api

 from llama_stack.providers.impls.meta_reference.safety.shields.base import (
     OnViolationAction,
@@ -36,11 +34,20 @@ def resolve_and_get_path(model_name: str) -> str:


 class MetaReferenceSafetyImpl(Safety):
-    def __init__(self, config: SafetyConfig, deps) -> None:
+    def __init__(self, config: SafetyConfig) -> None:
         self.config = config
-        self.inference_api = deps[Api.inference]

     async def initialize(self) -> None:
+        shield_cfg = self.config.llama_guard_shield
+        if shield_cfg is not None:
+            model_dir = resolve_and_get_path(shield_cfg.model)
+            _ = LlamaGuardShield.instance(
+                model_dir=model_dir,
+                excluded_categories=shield_cfg.excluded_categories,
+                disable_input_check=shield_cfg.disable_input_check,
+                disable_output_check=shield_cfg.disable_output_check,
+            )
+
         shield_cfg = self.config.prompt_guard_shield
         if shield_cfg is not None:
             model_dir = resolve_and_get_path(shield_cfg.model)
@@ -84,18 +91,11 @@ class MetaReferenceSafetyImpl(Safety):
     def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
         cfg = self.config
         if typ == MetaReferenceShieldType.llama_guard:
-            cfg = cfg.llama_guard_shield
             assert (
-                cfg is not None
+                cfg.llama_guard_shield is not None
             ), "Cannot use LlamaGuardShield since not present in config"
-            return LlamaGuardShield(
-                model=cfg.model,
-                inference_api=self.inference_api,
-                excluded_categories=cfg.excluded_categories,
-                disable_input_check=cfg.disable_input_check,
-                disable_output_check=cfg.disable_output_check,
-            )
+            model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
+            return LlamaGuardShield.instance(model_dir=model_dir)
         elif typ == MetaReferenceShieldType.jailbreak_shield:
             assert (
                 cfg.prompt_guard_shield is not None
```
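
`initialize()` now primes a module-level `LlamaGuardShield` singleton from the config, and `get_shield_impl()` hands that same object back via `LlamaGuardShield.instance(model_dir=model_dir)` instead of constructing an inference-backed shield per call. A self-contained toy of that pattern (the names mirror the diff, but `ToyShield` is illustrative rather than the real class):

```python
from typing import Optional

_INSTANCE: Optional["ToyShield"] = None


class ToyShield:
    @staticmethod
    def instance(model_dir: Optional[str] = None) -> "ToyShield":
        # The first call constructs and caches the shield; later calls
        # ignore their arguments and return the cached object.
        global _INSTANCE
        if _INSTANCE is None:
            _INSTANCE = ToyShield(model_dir)
        return _INSTANCE

    def __init__(self, model_dir: Optional[str]) -> None:
        assert model_dir is not None, "model_dir must be provided on first use"
        self.model_dir = model_dir


# initialize() plays the role of the first, fully parameterized call ...
ToyShield.instance(model_dir="/checkpoints/llama-guard")
# ... so get_shield_impl() can call instance() with only model_dir and still
# get back the object configured with excluded_categories, disable_*_check, etc.
assert ToyShield.instance() is ToyShield.instance(model_dir="ignored")
```

This is also why `initialize()` must run before `get_shield_impl()`: arguments passed to `instance()` after the first call are silently ignored.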

llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py

```diff
@@ -9,8 +9,9 @@ import re
 from string import Template
 from typing import List, Optional

+import torch
 from llama_models.llama3.api.datatypes import Message, Role
-from llama_stack.apis.inference import *  # noqa: F403
+from transformers import AutoModelForCausalLM, AutoTokenizer

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse

@@ -99,17 +100,39 @@ PROMPT_TEMPLATE = Template(


 class LlamaGuardShield(ShieldBase):
-    def __init__(
-        self,
-        model: str,
-        inference_api: Inference,
+    @staticmethod
+    def instance(
+        on_violation_action=OnViolationAction.RAISE,
+        model_dir: str = None,
         excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
+    ) -> "LlamaGuardShield":
+        global _INSTANCE
+        if _INSTANCE is None:
+            _INSTANCE = LlamaGuardShield(
+                on_violation_action,
+                model_dir,
+                excluded_categories,
+                disable_input_check,
+                disable_output_check,
+            )
+        return _INSTANCE
+
+    def __init__(
+        self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+        model_dir: str = None,
+        excluded_categories: List[str] = None,
+        disable_input_check: bool = False,
+        disable_output_check: bool = False,
     ):
         super().__init__(on_violation_action)

+        dtype = torch.bfloat16
+
+        assert model_dir is not None, "Llama Guard model_dir is None"
+
         if excluded_categories is None:
             excluded_categories = []

@@ -117,12 +140,18 @@ class LlamaGuardShield(ShieldBase):
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"

-        self.model = model
-        self.inference_api = inference_api
+        self.device = "cuda"
         self.excluded_categories = excluded_categories
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check

+        # load model
+        torch_dtype = torch.bfloat16
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_dir, torch_dtype=torch_dtype, device_map=self.device
+        )
+
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -183,21 +212,26 @@ class LlamaGuardShield(ShieldBase):
             )
         else:
             prompt = self.build_prompt(messages)
+            llama_guard_input = {
+                "role": "user",
+                "content": prompt,
+            }
+            input_ids = self.tokenizer.apply_chat_template(
+                [llama_guard_input], return_tensors="pt", tokenize=True
+            ).to(self.device)
+            prompt_len = input_ids.shape[1]
+            output = self.model.generate(
+                input_ids=input_ids,
+                max_new_tokens=20,
+                output_scores=True,
+                return_dict_in_generate=True,
+                pad_token_id=0,
+            )
+            generated_tokens = output.sequences[:, prompt_len:]

-            # TODO: llama-stack inference protocol has issues with non-streaming inference code
-            content = ""
-            async for chunk in self.inference_api.chat_completion(
-                model=self.model,
-                messages=[
-                    UserMessage(content=prompt),
-                ],
-                stream=True,
-            ):
-                event = chunk.event
-                if event.event_type == ChatCompletionResponseEventType.progress:
-                    assert isinstance(event.delta, str)
-                    content += event.delta
-
-            content = content.strip()
-            shield_response = self.get_shield_response(content)
+            response = self.tokenizer.decode(
+                generated_tokens[0], skip_special_tokens=True
+            )
+            response = response.strip()
+            shield_response = self.get_shield_response(response)
             return shield_response
```
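
Condensed, the restored path loads the checkpoint with `transformers` and classifies a rendered prompt locally instead of round-tripping through the inference API. A standalone sketch of that flow, where the checkpoint path is hypothetical, the prompt is elided, and a CUDA device is assumed because the reverted code hard-codes `self.device = "cuda"`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_DIR = "/path/to/llama-guard"  # hypothetical local checkpoint directory

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR, torch_dtype=torch.bfloat16, device_map="cuda"
)

# In the shield, `prompt` is rendered from PROMPT_TEMPLATE plus the conversation.
prompt = "..."
input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}], return_tensors="pt", tokenize=True
).to("cuda")

output = model.generate(
    input_ids=input_ids,
    max_new_tokens=20,
    return_dict_in_generate=True,
    pad_token_id=0,
)
# Only the newly generated tokens carry the verdict ("safe" or "unsafe\nS<n>"),
# which get_shield_response() then maps onto a ShieldResponse.
verdict = tokenizer.decode(
    output.sequences[0, input_ids.shape[1]:], skip_special_tokens=True
).strip()
print(verdict)
```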

llama_stack/providers/registry/safety.py

```diff
@@ -21,15 +21,13 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.safety,
             provider_id="meta-reference",
             pip_packages=[
+                "accelerate",
                 "codeshield",
+                "torch",
                 "transformers",
-                "torch --index-url https://download.pytorch.org/whl/cpu",
             ],
             module="llama_stack.providers.impls.meta_reference.safety",
             config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
-            api_dependencies=[
-                Api.inference,
-            ],
         ),
         remote_provider_spec(
             api=Api.safety,
```
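
Assembled, the reverted registry entry pulls the heavier local dependencies back in (plain `torch` plus `accelerate` instead of the CPU wheel index pin) and drops `Api.inference` from `api_dependencies`, so a distribution can run this safety provider without an inference provider attached. The sketch below only restates the entry for context; the import location of `InlineProviderSpec` and any spec fields not visible in the hunk are assumptions:

```python
# Assumed import location; the field values are copied from the hunk above.
from llama_stack.distribution.datatypes import Api, InlineProviderSpec

meta_reference_safety = InlineProviderSpec(
    api=Api.safety,
    provider_id="meta-reference",
    pip_packages=[
        "accelerate",
        "codeshield",
        "torch",
        "transformers",
    ],
    module="llama_stack.providers.impls.meta_reference.safety",
    config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
)
```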