Mirror of https://github.com/meta-llama/llama-stack.git
Use inference APIs for running llama guard
Test Plan:

First, start a TGI container serving the `meta-llama/Llama-Guard-3-8B` model on port 5099 (see https://github.com/meta-llama/llama-stack/pull/53 and its description for how). Then run llama-stack with the following run config:

```
image_name: safety
docker_image: null
conda_env: safety
apis_to_serve:
- models
- inference
- shields
- safety
api_providers:
  inference:
    providers:
    - remote::tgi
  safety:
    providers:
    - meta-reference
  telemetry:
    provider_id: meta-reference
    config: {}
routing_table:
  inference:
  - provider_id: remote::tgi
    config:
      url: http://localhost:5099
      api_token: null
      hf_endpoint_name: null
    routing_key: Llama-Guard-3-8B
  safety:
  - provider_id: meta-reference
    config:
      llama_guard_shield:
        model: Llama-Guard-3-8B
        excluded_categories: []
        disable_input_check: false
        disable_output_check: false
      prompt_guard_shield: null
    routing_key: llama_guard
```

Now run `python -m llama_stack.apis.safety.client localhost <port>` and check that the llama_guard shield calls run correctly. (The injection_shield calls fail as expected, since we have not set up a router for them.)
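The routing keys in this config are what the client exercises: shield requests routed under `llama_guard` resolve to the meta-reference safety provider, and that provider's Llama Guard checks are in turn routed under `Llama-Guard-3-8B` to the TGI inference adapter. The snippet below is a simplified stand-in (not llama-stack code) that illustrates why the injection_shield calls fail: they have no routing entry, and the routing table now raises instead of returning None (see the `CommonRoutingTableImpl` diff further down).

```python
# Simplified stand-in for routing-key resolution under the run config above.
providers = {
    "Llama-Guard-3-8B": "remote::tgi",   # inference routing entry
    "llama_guard": "meta-reference",     # safety routing entry
}

def get_provider_impl(routing_key: str) -> str:
    if routing_key not in providers:
        raise ValueError(f"Could not find provider for {routing_key}")
    return providers[routing_key]

print(get_provider_impl("llama_guard"))        # -> meta-reference
try:
    get_provider_impl("injection_shield")      # no routing entry configured
except ValueError as e:
    print(e)                                   # Could not find provider for injection_shield
```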
Parent: c4534217c8
Commit: 0d2eb3bd25
9 changed files with 56 additions and 81 deletions
```diff
@@ -190,7 +190,7 @@ class Inference(Protocol):
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
-        tools: Optional[List[ToolDefinition]] = list,
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
```
```diff
@@ -103,8 +103,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        # TODO: we need to fix streaming response to align provider implementations with Protocol.
-        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
+        params = dict(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
@@ -113,6 +112,10 @@ class InferenceRouter(Inference):
             tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+        )
+        # TODO: we need to fix streaming response to align provider implementations with Protocol.
+        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
+            **params
         ):
             yield chunk

```
```diff
@@ -33,8 +33,10 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.providers.values():
             await p.shutdown()

-    def get_provider_impl(self, routing_key: str) -> Optional[Any]:
-        return self.providers.get(routing_key)
+    def get_provider_impl(self, routing_key: str) -> Any:
+        if routing_key not in self.providers:
+            raise ValueError(f"Could not find provider for {routing_key}")
+        return self.providers[routing_key]

     def get_routing_keys(self) -> List[str]:
         return self.routing_keys
```
```diff
@@ -368,17 +368,19 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
         providers = all_providers[info.router_api]

         inner_specs = []
+        inner_deps = []
         for rt_entry in routing_table:
             if rt_entry.provider_id not in providers:
                 raise ValueError(
                     f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
                 )
             inner_specs.append(providers[rt_entry.provider_id])
+            inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)

         specs[source_api] = RoutingTableProviderSpec(
             api=source_api,
             module="llama_stack.distribution.routers",
-            api_dependencies=[],
+            api_dependencies=inner_deps,
             inner_specs=inner_specs,
         )
         configs[source_api] = routing_table
```
```diff
@@ -119,7 +119,7 @@ class TGIAdapter(Inference):
             )
             stop_reason = None
             if response.details.finish_reason:
-                if response.details.finish_reason == "stop":
+                if response.details.finish_reason in ["stop", "eos_token"]:
                     stop_reason = StopReason.end_of_turn
                 elif response.details.finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
```
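For context, here is a stand-alone version of the mapping this hunk changes (the `StopReason` enum below is a stub, not the llama-stack type): TGI reports `eos_token` when generation stops because the model emitted its end-of-sequence token, a case that previously fell through and left `stop_reason` as `None`.

```python
from enum import Enum
from typing import Optional

class StopReason(Enum):          # stub standing in for the llama-stack enum
    end_of_turn = "end_of_turn"
    out_of_tokens = "out_of_tokens"

def map_finish_reason(finish_reason: str) -> Optional[StopReason]:
    # Mirrors the updated TGIAdapter logic: both "stop" and "eos_token"
    # count as a normal end of turn; "length" means the token budget ran out.
    if finish_reason in ["stop", "eos_token"]:
        return StopReason.end_of_turn
    if finish_reason == "length":
        return StopReason.out_of_tokens
    return None

print(map_finish_reason("eos_token"))   # -> StopReason.end_of_turn
```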
```diff
@@ -7,11 +7,11 @@
 from .config import SafetyConfig


-async def get_provider_impl(config: SafetyConfig, _deps):
+async def get_provider_impl(config: SafetyConfig, deps):
     from .safety import MetaReferenceSafetyImpl

     assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"

-    impl = MetaReferenceSafetyImpl(config)
+    impl = MetaReferenceSafetyImpl(config, deps)
     await impl.initialize()
     return impl
```
```diff
@@ -7,8 +7,10 @@
 from llama_models.sku_list import resolve_model

 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.distribution.datatypes import Api

 from llama_stack.providers.impls.meta_reference.safety.shields.base import (
     OnViolationAction,
```
```diff
@@ -34,20 +36,11 @@ def resolve_and_get_path(model_name: str) -> str:


 class MetaReferenceSafetyImpl(Safety):
-    def __init__(self, config: SafetyConfig) -> None:
+    def __init__(self, config: SafetyConfig, deps) -> None:
         self.config = config
+        self.inference_api = deps[Api.inference]

     async def initialize(self) -> None:
-        shield_cfg = self.config.llama_guard_shield
-        if shield_cfg is not None:
-            model_dir = resolve_and_get_path(shield_cfg.model)
-            _ = LlamaGuardShield.instance(
-                model_dir=model_dir,
-                excluded_categories=shield_cfg.excluded_categories,
-                disable_input_check=shield_cfg.disable_input_check,
-                disable_output_check=shield_cfg.disable_output_check,
-            )
-
         shield_cfg = self.config.prompt_guard_shield
         if shield_cfg is not None:
             model_dir = resolve_and_get_path(shield_cfg.model)
```
```diff
@@ -91,11 +84,18 @@ class MetaReferenceSafetyImpl(Safety):
     def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
         cfg = self.config
         if typ == MetaReferenceShieldType.llama_guard:
+            cfg = cfg.llama_guard_shield
             assert (
-                cfg.llama_guard_shield is not None
+                cfg is not None
             ), "Cannot use LlamaGuardShield since not present in config"
-            model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
-            return LlamaGuardShield.instance(model_dir=model_dir)
+            return LlamaGuardShield(
+                model=cfg.model,
+                inference_api=self.inference_api,
+                excluded_categories=cfg.excluded_categories,
+                disable_input_check=cfg.disable_input_check,
+                disable_output_check=cfg.disable_output_check,
+            )
         elif typ == MetaReferenceShieldType.jailbreak_shield:
             assert (
                 cfg.prompt_guard_shield is not None
```
```diff
@@ -9,9 +9,8 @@ import re
 from string import Template
 from typing import List, Optional

-import torch
 from llama_models.llama3.api.datatypes import Message, Role
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from llama_stack.apis.inference import *  # noqa: F403

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse

```
```diff
@@ -100,39 +99,17 @@ PROMPT_TEMPLATE = Template(


 class LlamaGuardShield(ShieldBase):
-    @staticmethod
-    def instance(
-        on_violation_action=OnViolationAction.RAISE,
-        model_dir: str = None,
-        excluded_categories: List[str] = None,
-        disable_input_check: bool = False,
-        disable_output_check: bool = False,
-    ) -> "LlamaGuardShield":
-        global _INSTANCE
-        if _INSTANCE is None:
-            _INSTANCE = LlamaGuardShield(
-                on_violation_action,
-                model_dir,
-                excluded_categories,
-                disable_input_check,
-                disable_output_check,
-            )
-        return _INSTANCE
-
     def __init__(
         self,
-        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
-        model_dir: str = None,
+        model: str,
+        inference_api: Inference,
         excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
     ):
         super().__init__(on_violation_action)

-        dtype = torch.bfloat16
-
-        assert model_dir is not None, "Llama Guard model_dir is None"
-
         if excluded_categories is None:
             excluded_categories = []

```
```diff
@@ -140,18 +117,12 @@ class LlamaGuardShield(ShieldBase):
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"

-        self.device = "cuda"
+        self.model = model
+        self.inference_api = inference_api
         self.excluded_categories = excluded_categories
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check

-        # load model
-        torch_dtype = torch.bfloat16
-        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_dir, torch_dtype=torch_dtype, device_map=self.device
-        )
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
```
```diff
@@ -212,26 +183,21 @@ class LlamaGuardShield(ShieldBase):
             )
         else:
             prompt = self.build_prompt(messages)
-            llama_guard_input = {
-                "role": "user",
-                "content": prompt,
-            }
-            input_ids = self.tokenizer.apply_chat_template(
-                [llama_guard_input], return_tensors="pt", tokenize=True
-            ).to(self.device)
-            prompt_len = input_ids.shape[1]
-            output = self.model.generate(
-                input_ids=input_ids,
-                max_new_tokens=20,
-                output_scores=True,
-                return_dict_in_generate=True,
-                pad_token_id=0,
-            )
-            generated_tokens = output.sequences[:, prompt_len:]

-            response = self.tokenizer.decode(
-                generated_tokens[0], skip_special_tokens=True
-            )
-            response = response.strip()
-            shield_response = self.get_shield_response(response)
+            # TODO: llama-stack inference protocol has issues with non-streaming inference code
+            content = ""
+            async for chunk in self.inference_api.chat_completion(
+                model=self.model,
+                messages=[
+                    UserMessage(content=prompt),
+                ],
+                stream=True,
+            ):
+                event = chunk.event
+                if event.event_type == ChatCompletionResponseEventType.progress:
+                    assert isinstance(event.delta, str)
+                    content += event.delta
+
+            content = content.strip()
+            shield_response = self.get_shield_response(content)
             return shield_response
```
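To make the new control flow in `run()` concrete, here is a self-contained sketch of the pattern the shield now relies on: stream a chat completion through the injected inference API, accumulate the text deltas, then parse the accumulated Llama Guard verdict. The chunk/event classes and the fake completion below are stubs, not the llama-stack datatypes.

```python
import asyncio
import re
from dataclasses import dataclass
from typing import AsyncGenerator, Optional

@dataclass
class Event:                     # stub for the streaming event type
    event_type: str              # "start" | "progress" | "complete"
    delta: str = ""

@dataclass
class Chunk:                     # stub for the streaming chunk type
    event: Event

async def fake_chat_completion(prompt: str) -> AsyncGenerator[Chunk, None]:
    # Pretend the Llama Guard model flagged the input under category S1.
    for piece in ["unsafe", "\n", "S1"]:
        yield Chunk(Event("progress", piece))

def check_unsafe_response(response: str) -> Optional[str]:
    # Same regex as the shield: "unsafe" on the first line, category on the next.
    match = re.match(r"^unsafe\n(.*)$", response)
    return match.group(1) if match else None

async def run_shield(prompt: str) -> Optional[str]:
    content = ""
    async for chunk in fake_chat_completion(prompt):
        if chunk.event.event_type == "progress":
            content += chunk.event.delta
    return check_unsafe_response(content.strip())

print(asyncio.run(run_shield("some user message")))   # -> S1
```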
```diff
@@ -15,13 +15,15 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.safety,
             provider_id="meta-reference",
             pip_packages=[
-                "accelerate",
                 "codeshield",
-                "torch",
                 "transformers",
+                "torch --index-url https://download.pytorch.org/whl/cpu",
             ],
             module="llama_stack.providers.impls.meta_reference.safety",
             config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
+            api_dependencies=[
+                Api.inference,
+            ],
         ),
         remote_provider_spec(
             api=Api.safety,
```
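Taken together with the safety-provider changes above, the `api_dependencies=[Api.inference]` declaration is what makes `deps[Api.inference]` available when the provider is constructed. A rough sketch of that wiring under stub types (the real resolution happens in `resolve_impls_with_routing`):

```python
from enum import Enum

class Api(Enum):                 # stub for the llama-stack Api enum
    inference = "inference"
    safety = "safety"

class FakeInferenceImpl:
    """Stand-in for whatever the inference routing table resolves to."""

class SafetyImpl:
    # Mirrors the MetaReferenceSafetyImpl.__init__(config, deps) shape
    def __init__(self, config: dict, deps: dict):
        self.config = config
        self.inference_api = deps[Api.inference]

# Because the provider spec lists Api.inference as a dependency, the
# distribution resolver instantiates inference first and passes it in:
deps = {Api.inference: FakeInferenceImpl()}
safety = SafetyImpl(config={"llama_guard_shield": {"model": "Llama-Guard-3-8B"}}, deps=deps)
print(type(safety.inference_api).__name__)   # -> FakeInferenceImpl
```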