Revert parts of 0d2eb3bd25

This commit is contained in:
Ashwin Bharambe 2024-09-24 19:20:26 -07:00
parent 059e50b389
commit a2465f3f9c
4 changed files with 75 additions and 43 deletions

View file

@ -7,11 +7,11 @@
from .config import SafetyConfig from .config import SafetyConfig
async def get_provider_impl(config: SafetyConfig, deps): async def get_provider_impl(config: SafetyConfig, _deps):
from .safety import MetaReferenceSafetyImpl from .safety import MetaReferenceSafetyImpl
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config, deps) impl = MetaReferenceSafetyImpl(config)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -7,10 +7,8 @@
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.impls.meta_reference.safety.shields.base import ( from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction, OnViolationAction,
@ -36,11 +34,20 @@ def resolve_and_get_path(model_name: str) -> str:
class MetaReferenceSafetyImpl(Safety): class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig, deps) -> None: def __init__(self, config: SafetyConfig) -> None:
self.config = config self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None: async def initialize(self) -> None:
shield_cfg = self.config.llama_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
_ = LlamaGuardShield.instance(
model_dir=model_dir,
excluded_categories=shield_cfg.excluded_categories,
disable_input_check=shield_cfg.disable_input_check,
disable_output_check=shield_cfg.disable_output_check,
)
shield_cfg = self.config.prompt_guard_shield shield_cfg = self.config.prompt_guard_shield
if shield_cfg is not None: if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model) model_dir = resolve_and_get_path(shield_cfg.model)
@ -84,18 +91,11 @@ class MetaReferenceSafetyImpl(Safety):
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config cfg = self.config
if typ == MetaReferenceShieldType.llama_guard: if typ == MetaReferenceShieldType.llama_guard:
cfg = cfg.llama_guard_shield
assert ( assert (
cfg is not None cfg.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config" ), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
return LlamaGuardShield( return LlamaGuardShield.instance(model_dir=model_dir)
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check,
)
elif typ == MetaReferenceShieldType.jailbreak_shield: elif typ == MetaReferenceShieldType.jailbreak_shield:
assert ( assert (
cfg.prompt_guard_shield is not None cfg.prompt_guard_shield is not None

View file

@ -9,8 +9,9 @@ import re
from string import Template from string import Template
from typing import List, Optional from typing import List, Optional
import torch
from llama_models.llama3.api.datatypes import Message, Role from llama_models.llama3.api.datatypes import Message, Role
from llama_stack.apis.inference import * # noqa: F403 from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@ -99,17 +100,39 @@ PROMPT_TEMPLATE = Template(
class LlamaGuardShield(ShieldBase): class LlamaGuardShield(ShieldBase):
def __init__( @staticmethod
self, def instance(
model: str, on_violation_action=OnViolationAction.RAISE,
inference_api: Inference, model_dir: str = None,
excluded_categories: List[str] = None, excluded_categories: List[str] = None,
disable_input_check: bool = False, disable_input_check: bool = False,
disable_output_check: bool = False, disable_output_check: bool = False,
) -> "LlamaGuardShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = LlamaGuardShield(
on_violation_action,
model_dir,
excluded_categories,
disable_input_check,
disable_output_check,
)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE, on_violation_action: OnViolationAction = OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
): ):
super().__init__(on_violation_action) super().__init__(on_violation_action)
dtype = torch.bfloat16
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None: if excluded_categories is None:
excluded_categories = [] excluded_categories = []
@ -117,12 +140,18 @@ class LlamaGuardShield(ShieldBase):
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]" ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.model = model self.device = "cuda"
self.inference_api = inference_api
self.excluded_categories = excluded_categories self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check self.disable_output_check = disable_output_check
# load model
torch_dtype = torch.bfloat16
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def check_unsafe_response(self, response: str) -> Optional[str]: def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response) match = re.match(r"^unsafe\n(.*)$", response)
if match: if match:
@ -183,21 +212,26 @@ class LlamaGuardShield(ShieldBase):
) )
else: else:
prompt = self.build_prompt(messages) prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",
"content": prompt,
}
input_ids = self.tokenizer.apply_chat_template(
[llama_guard_input], return_tensors="pt", tokenize=True
).to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
# TODO: llama-stack inference protocol has issues with non-streaming inference code response = self.tokenizer.decode(
content = "" generated_tokens[0], skip_special_tokens=True
async for chunk in self.inference_api.chat_completion( )
model=self.model, response = response.strip()
messages=[ shield_response = self.get_shield_response(response)
UserMessage(content=prompt),
],
stream=True,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.progress:
assert isinstance(event.delta, str)
content += event.delta
content = content.strip()
shield_response = self.get_shield_response(content)
return shield_response return shield_response

View file

@ -21,15 +21,13 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety, api=Api.safety,
provider_id="meta-reference", provider_id="meta-reference",
pip_packages=[ pip_packages=[
"accelerate",
"codeshield", "codeshield",
"torch",
"transformers", "transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
], ],
module="llama_stack.providers.impls.meta_reference.safety", module="llama_stack.providers.impls.meta_reference.safety",
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig", config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
api_dependencies=[
Api.inference,
],
), ),
remote_provider_spec( remote_provider_spec(
api=Api.safety, api=Api.safety,