Revert parts of 0d2eb3bd25

Ashwin Bharambe 2024-09-24 19:20:26 -07:00
parent 059e50b389
commit a2465f3f9c
4 changed files with 75 additions and 43 deletions

llama_stack/providers/impls/meta_reference/safety/__init__.py

@@ -7,11 +7,11 @@
 from .config import SafetyConfig
 
 
-async def get_provider_impl(config: SafetyConfig, deps):
+async def get_provider_impl(config: SafetyConfig, _deps):
     from .safety import MetaReferenceSafetyImpl
 
     assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
 
-    impl = MetaReferenceSafetyImpl(config, deps)
+    impl = MetaReferenceSafetyImpl(config)
     await impl.initialize()
     return impl
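
After this revert the provider entrypoint builds the implementation from its config alone; the dependency mapping is still accepted for interface compatibility but goes unused (hence the `_deps` rename). A minimal, self-contained sketch of that shape, using made-up demo classes rather than the actual llama-stack types:

import asyncio
from dataclasses import dataclass


@dataclass
class DemoSafetyConfig:
    # placeholder field; the real SafetyConfig carries the shield settings
    model: str = "Llama-Guard"


class DemoSafetyImpl:
    def __init__(self, config: DemoSafetyConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        # the real provider loads its shield models here
        pass


async def get_provider_impl(config: DemoSafetyConfig, _deps):
    # _deps is accepted for interface compatibility but not used after the revert
    impl = DemoSafetyImpl(config)
    await impl.initialize()
    return impl


impl = asyncio.run(get_provider_impl(DemoSafetyConfig(), _deps={}))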

llama_stack/providers/impls/meta_reference/safety/safety.py

@@ -7,10 +7,8 @@
 from llama_models.sku_list import resolve_model
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import Api
 
 from llama_stack.providers.impls.meta_reference.safety.shields.base import (
     OnViolationAction,
@@ -36,11 +34,20 @@ def resolve_and_get_path(model_name: str) -> str:
 
 
 class MetaReferenceSafetyImpl(Safety):
-    def __init__(self, config: SafetyConfig, deps) -> None:
+    def __init__(self, config: SafetyConfig) -> None:
         self.config = config
-        self.inference_api = deps[Api.inference]
 
     async def initialize(self) -> None:
+        shield_cfg = self.config.llama_guard_shield
+        if shield_cfg is not None:
+            model_dir = resolve_and_get_path(shield_cfg.model)
+            _ = LlamaGuardShield.instance(
+                model_dir=model_dir,
+                excluded_categories=shield_cfg.excluded_categories,
+                disable_input_check=shield_cfg.disable_input_check,
+                disable_output_check=shield_cfg.disable_output_check,
+            )
+
         shield_cfg = self.config.prompt_guard_shield
         if shield_cfg is not None:
             model_dir = resolve_and_get_path(shield_cfg.model)
@@ -84,18 +91,11 @@ class MetaReferenceSafetyImpl(Safety):
     def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
         cfg = self.config
         if typ == MetaReferenceShieldType.llama_guard:
-            cfg = cfg.llama_guard_shield
             assert (
-                cfg is not None
+                cfg.llama_guard_shield is not None
             ), "Cannot use LlamaGuardShield since not present in config"
-
-            return LlamaGuardShield(
-                model=cfg.model,
-                inference_api=self.inference_api,
-                excluded_categories=cfg.excluded_categories,
-                disable_input_check=cfg.disable_input_check,
-                disable_output_check=cfg.disable_output_check,
-            )
+            model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
+            return LlamaGuardShield.instance(model_dir=model_dir)
         elif typ == MetaReferenceShieldType.jailbreak_shield:
             assert (
                 cfg.prompt_guard_shield is not None
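
The reverted implementation splits shield setup into two phases: initialize() eagerly constructs the Llama Guard singleton from the full shield config, and get_shield_impl() later fetches it via LlamaGuardShield.instance(model_dir=...). Because the accessor only constructs on first use, arguments passed on later calls are effectively ignored once the instance exists. A standalone sketch of that create-once / fetch-later pattern, with generic names rather than the llama-stack classes:

from typing import List, Optional

_INSTANCE: Optional["GuardModel"] = None


class GuardModel:
    def __init__(self, model_dir: str, excluded_categories: Optional[List[str]] = None) -> None:
        # expensive setup (e.g. loading model weights) would happen exactly once here
        self.model_dir = model_dir
        self.excluded_categories = excluded_categories or []

    @staticmethod
    def instance(
        model_dir: str = None, excluded_categories: List[str] = None
    ) -> "GuardModel":
        global _INSTANCE
        if _INSTANCE is None:
            _INSTANCE = GuardModel(model_dir, excluded_categories)
        return _INSTANCE


# "initialize()": build the singleton once with the full configuration
GuardModel.instance(model_dir="/models/llama-guard", excluded_categories=["S1"])

# "get_shield_impl()": later callers get the same object back; arguments passed
# here no longer matter because the instance already exists
shield = GuardModel.instance(model_dir="/models/llama-guard")
assert shield.excluded_categories == ["S1"]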

llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py

@@ -9,8 +9,9 @@ import re
 from string import Template
 from typing import List, Optional
 
+import torch
 from llama_models.llama3.api.datatypes import Message, Role
-from llama_stack.apis.inference import *  # noqa: F403
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
 
@@ -99,17 +100,39 @@ PROMPT_TEMPLATE = Template(
 
 
 class LlamaGuardShield(ShieldBase):
-    def __init__(
-        self,
-        model: str,
-        inference_api: Inference,
+    @staticmethod
+    def instance(
+        on_violation_action=OnViolationAction.RAISE,
+        model_dir: str = None,
         excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
+    ) -> "LlamaGuardShield":
+        global _INSTANCE
+        if _INSTANCE is None:
+            _INSTANCE = LlamaGuardShield(
+                on_violation_action,
+                model_dir,
+                excluded_categories,
+                disable_input_check,
+                disable_output_check,
+            )
+        return _INSTANCE
+
+    def __init__(
+        self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+        model_dir: str = None,
+        excluded_categories: List[str] = None,
+        disable_input_check: bool = False,
+        disable_output_check: bool = False,
     ):
         super().__init__(on_violation_action)
 
+        dtype = torch.bfloat16
+
+        assert model_dir is not None, "Llama Guard model_dir is None"
+
         if excluded_categories is None:
             excluded_categories = []
 
@@ -117,12 +140,18 @@ class LlamaGuardShield(ShieldBase):
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
 
-        self.model = model
-        self.inference_api = inference_api
+        self.device = "cuda"
         self.excluded_categories = excluded_categories
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
 
+        # load model
+        torch_dtype = torch.bfloat16
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_dir, torch_dtype=torch_dtype, device_map=self.device
+        )
+
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -183,21 +212,26 @@
             )
         else:
             prompt = self.build_prompt(messages)
+            llama_guard_input = {
+                "role": "user",
+                "content": prompt,
+            }
+            input_ids = self.tokenizer.apply_chat_template(
+                [llama_guard_input], return_tensors="pt", tokenize=True
+            ).to(self.device)
+            prompt_len = input_ids.shape[1]
+            output = self.model.generate(
+                input_ids=input_ids,
+                max_new_tokens=20,
+                output_scores=True,
+                return_dict_in_generate=True,
+                pad_token_id=0,
+            )
+            generated_tokens = output.sequences[:, prompt_len:]
 
-            # TODO: llama-stack inference protocol has issues with non-streaming inference code
-            content = ""
-            async for chunk in self.inference_api.chat_completion(
-                model=self.model,
-                messages=[
-                    UserMessage(content=prompt),
-                ],
-                stream=True,
-            ):
-                event = chunk.event
-                if event.event_type == ChatCompletionResponseEventType.progress:
-                    assert isinstance(event.delta, str)
-                    content += event.delta
-
-            content = content.strip()
-            shield_response = self.get_shield_response(content)
+            response = self.tokenizer.decode(
+                generated_tokens[0], skip_special_tokens=True
+            )
+            response = response.strip()
+            shield_response = self.get_shield_response(response)
             return shield_response
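
The restored path runs Llama Guard locally through Hugging Face transformers instead of routing through the inference API. A self-contained sketch of that tokenize-generate-decode flow; the checkpoint path, device, and prompt are placeholders, and max_new_tokens=20 mirrors the short "safe" / "unsafe" verdict the shield parses:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_DIR = "/path/to/llama-guard-checkpoint"  # placeholder: a locally resolved checkpoint

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR, torch_dtype=torch.bfloat16, device_map="cuda"
)

# Llama Guard sees the rendered safety prompt as a single user turn.
prompt = "..."  # in the real shield this comes from build_prompt(messages)
input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}], return_tensors="pt", tokenize=True
).to("cuda")

output = model.generate(
    input_ids=input_ids,
    max_new_tokens=20,  # the verdict is short: "safe" or "unsafe" plus a category code
    return_dict_in_generate=True,
    pad_token_id=0,
)

# Drop the prompt tokens and decode only the newly generated verdict.
verdict = tokenizer.decode(
    output.sequences[0, input_ids.shape[1]:], skip_special_tokens=True
).strip()
print(verdict)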

llama_stack/providers/registry/safety.py

@@ -21,15 +21,13 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.safety,
             provider_id="meta-reference",
             pip_packages=[
+                "accelerate",
                 "codeshield",
+                "torch",
                 "transformers",
-                "torch --index-url https://download.pytorch.org/whl/cpu",
             ],
             module="llama_stack.providers.impls.meta_reference.safety",
             config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
-            api_dependencies=[
-                Api.inference,
-            ],
         ),
         remote_provider_spec(
             api=Api.safety,