mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
Push registration methods onto the backing providers
This commit is contained in:
parent
5a7b01d292
commit
4215cc9331
14 changed files with 269 additions and 220 deletions
|
|
@ -10,23 +10,36 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
|
|||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api, RoutableProvider
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceShieldType, SafetyConfig
|
||||
from .config import SafetyConfig
|
||||
|
||||
from .shields import CodeScannerShield, LlamaGuardShield, ShieldBase
|
||||
from .shields import (
|
||||
CodeScannerShield,
|
||||
InjectionShield,
|
||||
JailbreakShield,
|
||||
LlamaGuardShield,
|
||||
ShieldBase,
|
||||
)
|
||||
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
||||
class MetaReferenceSafetyImpl(Safety):
|
||||
def __init__(self, config: SafetyConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
self.registered_shields = []
|
||||
|
||||
self.available_shields = [ShieldType.code_scanner.value]
|
||||
if config.llama_guard_shield:
|
||||
self.available_shields.append(ShieldType.llama_guard.value)
|
||||
if config.enable_prompt_guard:
|
||||
self.available_shields.append(ShieldType.prompt_guard.value)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if self.config.enable_prompt_guard:
|
||||
|
|
@ -38,11 +51,20 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||
for key in routing_keys:
|
||||
if key not in available_shields:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type not in self.available_shields:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
|
||||
self.registered_shields.append(shield)
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return self.registered_shields
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
|
||||
for shield in self.registered_shields:
|
||||
if shield.identifier == identifier:
|
||||
return shield
|
||||
return None
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
@ -50,10 +72,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
|||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||
assert shield_type in available_shields, f"Unknown shield {shield_type}"
|
||||
shield_def = await self.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
|
||||
shield = self.get_shield_impl(shield_def)
|
||||
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
|
|
@ -79,30 +102,24 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
|||
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||
cfg = self.config
|
||||
if typ == MetaReferenceShieldType.llama_guard:
|
||||
cfg = cfg.llama_guard_shield
|
||||
assert (
|
||||
cfg is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
|
||||
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
|
||||
if shield.type == ShieldType.llama_guard.value:
|
||||
cfg = self.config.llama_guard_shield
|
||||
return LlamaGuardShield(
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
)
|
||||
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
||||
from .shields import JailbreakShield
|
||||
|
||||
elif shield.type == ShieldType.prompt_guard.value:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
return JailbreakShield.instance(model_dir)
|
||||
elif typ == MetaReferenceShieldType.injection_shield:
|
||||
from .shields import InjectionShield
|
||||
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif typ == MetaReferenceShieldType.code_scanner_guard:
|
||||
subtype = shield.params.get("prompt_guard_type", "injection")
|
||||
if subtype == "injection":
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif subtype == "jailbreak":
|
||||
return JailbreakShield.instance(model_dir)
|
||||
else:
|
||||
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
||||
elif shield.type == ShieldType.code_scanner.value:
|
||||
return CodeScannerShield.instance()
|
||||
else:
|
||||
raise ValueError(f"Unknown shield type: {typ}")
|
||||
raise ValueError(f"Unknown shield type: {shield.type}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue