From a2465f3f9c33aac0dd044f5a2868d79bb2f70eda Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 24 Sep 2024 19:20:26 -0700
Subject: [PATCH] Revert parts of 0d2eb3bd258d131ddd33cbfffd73cf4d631c458d

---
 .../impls/meta_reference/safety/__init__.py |  4 +-
 .../impls/meta_reference/safety/safety.py   | 28 +++----
 .../safety/shields/llama_guard.py           | 80 +++++++++++++------
 llama_stack/providers/registry/safety.py    |  6 +-
 4 files changed, 75 insertions(+), 43 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/safety/__init__.py b/llama_stack/providers/impls/meta_reference/safety/__init__.py
index 6c686120c..ad175ce46 100644
--- a/llama_stack/providers/impls/meta_reference/safety/__init__.py
+++ b/llama_stack/providers/impls/meta_reference/safety/__init__.py
@@ -7,11 +7,11 @@
 from .config import SafetyConfig
 
 
-async def get_provider_impl(config: SafetyConfig, deps):
+async def get_provider_impl(config: SafetyConfig, _deps):
     from .safety import MetaReferenceSafetyImpl
 
     assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
 
-    impl = MetaReferenceSafetyImpl(config, deps)
+    impl = MetaReferenceSafetyImpl(config)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py
index 6bb851596..6cf8a79d2 100644
--- a/llama_stack/providers/impls/meta_reference/safety/safety.py
+++ b/llama_stack/providers/impls/meta_reference/safety/safety.py
@@ -7,10 +7,8 @@
 from llama_models.sku_list import resolve_model
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import Api
 
 from llama_stack.providers.impls.meta_reference.safety.shields.base import (
     OnViolationAction,
@@ -36,11 +34,20 @@ def resolve_and_get_path(model_name: str) -> str:
 
 
 class MetaReferenceSafetyImpl(Safety):
-    def __init__(self, config: SafetyConfig, deps) -> None:
+    def __init__(self, config: SafetyConfig) -> None:
         self.config = config
-        self.inference_api = deps[Api.inference]
 
     async def initialize(self) -> None:
+        shield_cfg = self.config.llama_guard_shield
+        if shield_cfg is not None:
+            model_dir = resolve_and_get_path(shield_cfg.model)
+            _ = LlamaGuardShield.instance(
+                model_dir=model_dir,
+                excluded_categories=shield_cfg.excluded_categories,
+                disable_input_check=shield_cfg.disable_input_check,
+                disable_output_check=shield_cfg.disable_output_check,
+            )
+
         shield_cfg = self.config.prompt_guard_shield
         if shield_cfg is not None:
             model_dir = resolve_and_get_path(shield_cfg.model)
@@ -84,18 +91,11 @@ class MetaReferenceSafetyImpl(Safety):
     def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
         cfg = self.config
         if typ == MetaReferenceShieldType.llama_guard:
-            cfg = cfg.llama_guard_shield
             assert (
-                cfg is not None
+                cfg.llama_guard_shield is not None
             ), "Cannot use LlamaGuardShield since not present in config"
-
-            return LlamaGuardShield(
-                model=cfg.model,
-                inference_api=self.inference_api,
-                excluded_categories=cfg.excluded_categories,
-                disable_input_check=cfg.disable_input_check,
-                disable_output_check=cfg.disable_output_check,
-            )
+            model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
+            return LlamaGuardShield.instance(model_dir=model_dir)
         elif typ == MetaReferenceShieldType.jailbreak_shield:
             assert (
                 cfg.prompt_guard_shield is not None
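Note (not part of the patch): in the safety.py hunks above, initialize() now pre-loads Llama Guard via LlamaGuardShield.instance(), and get_shield_impl() later calls instance() again to fetch the same object, so the weights are loaded once per process instead of being routed through the inference API per request. A minimal sketch of the module-level singleton idiom this relies on; the class below is illustrative only, and a module-level _INSTANCE = None is assumed to already exist in llama_guard.py:

_INSTANCE = None


class ExpensiveShield:
    """Stand-in for LlamaGuardShield: construction is assumed to be costly."""

    @staticmethod
    def instance(model_dir: str = None) -> "ExpensiveShield":
        global _INSTANCE
        if _INSTANCE is None:
            # The first caller (initialize()) pays the load cost; later callers
            # (get_shield_impl()) get the cached object back.
            _INSTANCE = ExpensiveShield(model_dir=model_dir)
        return _INSTANCE

    def __init__(self, model_dir: str = None):
        self.model_dir = model_dir  # stands in for loading model weights
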
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
index 0f252e5c3..c29361b95 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
@@ -9,8 +9,9 @@ import re
 from string import Template
 from typing import List, Optional
 
+import torch
 from llama_models.llama3.api.datatypes import Message, Role
-from llama_stack.apis.inference import *  # noqa: F403
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
 
@@ -99,17 +100,39 @@ PROMPT_TEMPLATE = Template(
 
 
 class LlamaGuardShield(ShieldBase):
-    def __init__(
-        self,
-        model: str,
-        inference_api: Inference,
+    @staticmethod
+    def instance(
+        on_violation_action=OnViolationAction.RAISE,
+        model_dir: str = None,
         excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
+    ) -> "LlamaGuardShield":
+        global _INSTANCE
+        if _INSTANCE is None:
+            _INSTANCE = LlamaGuardShield(
+                on_violation_action,
+                model_dir,
+                excluded_categories,
+                disable_input_check,
+                disable_output_check,
+            )
+        return _INSTANCE
+
+    def __init__(
+        self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+        model_dir: str = None,
+        excluded_categories: List[str] = None,
+        disable_input_check: bool = False,
+        disable_output_check: bool = False,
     ):
         super().__init__(on_violation_action)
 
+        dtype = torch.bfloat16
+
+        assert model_dir is not None, "Llama Guard model_dir is None"
+
         if excluded_categories is None:
             excluded_categories = []
 
@@ -117,12 +140,18 @@ class LlamaGuardShield(ShieldBase):
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
 
-        self.model = model
-        self.inference_api = inference_api
+        self.device = "cuda"
         self.excluded_categories = excluded_categories
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
 
+        # load model
+        torch_dtype = torch.bfloat16
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_dir, torch_dtype=torch_dtype, device_map=self.device
+        )
+
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -183,21 +212,26 @@ class LlamaGuardShield(ShieldBase):
             )
         else:
             prompt = self.build_prompt(messages)
+            llama_guard_input = {
+                "role": "user",
+                "content": prompt,
+            }
+            input_ids = self.tokenizer.apply_chat_template(
+                [llama_guard_input], return_tensors="pt", tokenize=True
+            ).to(self.device)
+            prompt_len = input_ids.shape[1]
+            output = self.model.generate(
+                input_ids=input_ids,
+                max_new_tokens=20,
+                output_scores=True,
+                return_dict_in_generate=True,
+                pad_token_id=0,
+            )
+            generated_tokens = output.sequences[:, prompt_len:]
 
-            # TODO: llama-stack inference protocol has issues with non-streaming inference code
-            content = ""
-            async for chunk in self.inference_api.chat_completion(
-                model=self.model,
-                messages=[
-                    UserMessage(content=prompt),
-                ],
-                stream=True,
-            ):
-                event = chunk.event
-                if event.event_type == ChatCompletionResponseEventType.progress:
-                    assert isinstance(event.delta, str)
-                    content += event.delta
-
-            content = content.strip()
-            shield_response = self.get_shield_response(content)
+            response = self.tokenizer.decode(
+                generated_tokens[0], skip_special_tokens=True
+            )
+            response = response.strip()
+            shield_response = self.get_shield_response(response)
         return shield_response
diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py
index 202690264..09aed4982 100644
--- a/llama_stack/providers/registry/safety.py
+++ b/llama_stack/providers/registry/safety.py
@@ -21,15 +21,13 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.safety,
             provider_id="meta-reference",
             pip_packages=[
+                "accelerate",
                 "codeshield",
+                "torch",
                 "transformers",
-                "torch --index-url https://download.pytorch.org/whl/cpu",
             ],
             module="llama_stack.providers.impls.meta_reference.safety",
             config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
-            api_dependencies=[
-                Api.inference,
-            ],
         ),
         remote_provider_spec(
             api=Api.safety,
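Note (not part of the patch): pulled together from the calls in the llama_guard.py hunks, a self-contained sketch of the direct-transformers path this revert restores. The model directory is a placeholder for whatever resolve_and_get_path(cfg.llama_guard_shield.model) returns, and a CUDA device is assumed, as in the shield itself:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def classify_with_llama_guard(prompt: str, model_dir: str, device: str = "cuda") -> str:
    # Run a single Llama Guard check the same way the reverted shield does.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir, torch_dtype=torch.bfloat16, device_map=device
    )

    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}], return_tensors="pt", tokenize=True
    ).to(device)
    prompt_len = input_ids.shape[1]

    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=20,
        return_dict_in_generate=True,
        pad_token_id=0,
    )
    # Llama Guard replies "safe" or "unsafe\n<category codes>", which the
    # shield's get_shield_response() then parses into a ShieldResponse.
    generated = output.sequences[:, prompt_len:]
    return tokenizer.decode(generated[0], skip_special_tokens=True).strip()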