forked from phoenix-oss/llama-stack-mirror
Use inference APIs for running llama guard
Test Plan: First, start a TGI container with `meta-llama/Llama-Guard-3-8B` model serving on port 5099. See https://github.com/meta-llama/llama-stack/pull/53 and its description for how. Then run llama-stack with the following run config: ``` image_name: safety docker_image: null conda_env: safety apis_to_serve: - models - inference - shields - safety api_providers: inference: providers: - remote::tgi safety: providers: - meta-reference telemetry: provider_id: meta-reference config: {} routing_table: inference: - provider_id: remote::tgi config: url: http://localhost:5099 api_token: null hf_endpoint_name: null routing_key: Llama-Guard-3-8B safety: - provider_id: meta-reference config: llama_guard_shield: model: Llama-Guard-3-8B excluded_categories: [] disable_input_check: false disable_output_check: false prompt_guard_shield: null routing_key: llama_guard ``` Now simply run `python -m llama_stack.apis.safety.client localhost <port>` and check that the llama_guard shield calls run correctly. (The injection_shield calls fail as expected since we have not set up a router for them.)
This commit is contained in:
parent
c4534217c8
commit
0d2eb3bd25
9 changed files with 56 additions and 81 deletions
|
@ -7,11 +7,11 @@
|
|||
from .config import SafetyConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SafetyConfig, _deps):
|
||||
async def get_provider_impl(config: SafetyConfig, deps):
|
||||
from .safety import MetaReferenceSafetyImpl
|
||||
|
||||
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceSafetyImpl(config)
|
||||
impl = MetaReferenceSafetyImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -7,8 +7,10 @@
|
|||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
|
@ -34,20 +36,11 @@ def resolve_and_get_path(model_name: str) -> str:
|
|||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety):
|
||||
def __init__(self, config: SafetyConfig) -> None:
|
||||
def __init__(self, config: SafetyConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
async def initialize(self) -> None:
|
||||
shield_cfg = self.config.llama_guard_shield
|
||||
if shield_cfg is not None:
|
||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
_ = LlamaGuardShield.instance(
|
||||
model_dir=model_dir,
|
||||
excluded_categories=shield_cfg.excluded_categories,
|
||||
disable_input_check=shield_cfg.disable_input_check,
|
||||
disable_output_check=shield_cfg.disable_output_check,
|
||||
)
|
||||
|
||||
shield_cfg = self.config.prompt_guard_shield
|
||||
if shield_cfg is not None:
|
||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
|
@ -91,11 +84,18 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||
cfg = self.config
|
||||
if typ == MetaReferenceShieldType.llama_guard:
|
||||
cfg = cfg.llama_guard_shield
|
||||
assert (
|
||||
cfg.llama_guard_shield is not None
|
||||
cfg is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
|
||||
return LlamaGuardShield.instance(model_dir=model_dir)
|
||||
|
||||
return LlamaGuardShield(
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
disable_input_check=cfg.disable_input_check,
|
||||
disable_output_check=cfg.disable_output_check,
|
||||
)
|
||||
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
||||
assert (
|
||||
cfg.prompt_guard_shield is not None
|
||||
|
|
|
@ -9,9 +9,8 @@ import re
|
|||
from string import Template
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from llama_models.llama3.api.datatypes import Message, Role
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
|
||||
|
||||
|
@ -100,39 +99,17 @@ PROMPT_TEMPLATE = Template(
|
|||
|
||||
|
||||
class LlamaGuardShield(ShieldBase):
|
||||
@staticmethod
|
||||
def instance(
|
||||
on_violation_action=OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
) -> "LlamaGuardShield":
|
||||
global _INSTANCE
|
||||
if _INSTANCE is None:
|
||||
_INSTANCE = LlamaGuardShield(
|
||||
on_violation_action,
|
||||
model_dir,
|
||||
excluded_categories,
|
||||
disable_input_check,
|
||||
disable_output_check,
|
||||
)
|
||||
return _INSTANCE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
||||
dtype = torch.bfloat16
|
||||
|
||||
assert model_dir is not None, "Llama Guard model_dir is None"
|
||||
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
|
||||
|
@ -140,18 +117,12 @@ class LlamaGuardShield(ShieldBase):
|
|||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
||||
self.device = "cuda"
|
||||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.disable_input_check = disable_input_check
|
||||
self.disable_output_check = disable_output_check
|
||||
|
||||
# load model
|
||||
torch_dtype = torch.bfloat16
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir, torch_dtype=torch_dtype, device_map=self.device
|
||||
)
|
||||
|
||||
def check_unsafe_response(self, response: str) -> Optional[str]:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
if match:
|
||||
|
@ -212,26 +183,21 @@ class LlamaGuardShield(ShieldBase):
|
|||
)
|
||||
else:
|
||||
prompt = self.build_prompt(messages)
|
||||
llama_guard_input = {
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
[llama_guard_input], return_tensors="pt", tokenize=True
|
||||
).to(self.device)
|
||||
prompt_len = input_ids.shape[1]
|
||||
output = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=20,
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
pad_token_id=0,
|
||||
)
|
||||
generated_tokens = output.sequences[:, prompt_len:]
|
||||
|
||||
response = self.tokenizer.decode(
|
||||
generated_tokens[0], skip_special_tokens=True
|
||||
)
|
||||
response = response.strip()
|
||||
shield_response = self.get_shield_response(response)
|
||||
# TODO: llama-stack inference protocol has issues with non-streaming inference code
|
||||
content = ""
|
||||
async for chunk in self.inference_api.chat_completion(
|
||||
model=self.model,
|
||||
messages=[
|
||||
UserMessage(content=prompt),
|
||||
],
|
||||
stream=True,
|
||||
):
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.progress:
|
||||
assert isinstance(event.delta, str)
|
||||
content += event.delta
|
||||
|
||||
content = content.strip()
|
||||
shield_response = self.get_shield_response(content)
|
||||
return shield_response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue