mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
add register in meta reference
This commit is contained in:
parent
b9474144b6
commit
d66293d498
1 changed files with 6 additions and 5 deletions
|
@ -30,9 +30,9 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
|
|
||||||
self.available_shields = []
|
self.available_shields = []
|
||||||
if config.llama_guard_shield:
|
if config.llama_guard_shield:
|
||||||
self.available_shields.append(ShieldType.llama_guard.value)
|
self.available_shields.append(ShieldType.llama_guard)
|
||||||
if config.enable_prompt_guard:
|
if config.enable_prompt_guard:
|
||||||
self.available_shields.append(ShieldType.prompt_guard.value)
|
self.available_shields.append(ShieldType.prompt_guard)
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
if self.config.enable_prompt_guard:
|
if self.config.enable_prompt_guard:
|
||||||
|
@ -43,7 +43,8 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_shield(self, shield: Shield) -> None:
|
async def register_shield(self, shield: Shield) -> None:
|
||||||
raise ValueError("Registering dynamic shields is not supported")
|
if shield.shield_type not in self.available_shields:
|
||||||
|
raise ValueError(f"Shield type {shield.shield_type} not supported")
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
|
@ -79,14 +80,14 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
return RunShieldResponse(violation=violation)
|
return RunShieldResponse(violation=violation)
|
||||||
|
|
||||||
def get_shield_impl(self, shield: Shield) -> ShieldBase:
|
def get_shield_impl(self, shield: Shield) -> ShieldBase:
|
||||||
if shield.shield_type == ShieldType.llama_guard.value:
|
if shield.shield_type == ShieldType.llama_guard:
|
||||||
cfg = self.config.llama_guard_shield
|
cfg = self.config.llama_guard_shield
|
||||||
return LlamaGuardShield(
|
return LlamaGuardShield(
|
||||||
model=cfg.model,
|
model=cfg.model,
|
||||||
inference_api=self.inference_api,
|
inference_api=self.inference_api,
|
||||||
excluded_categories=cfg.excluded_categories,
|
excluded_categories=cfg.excluded_categories,
|
||||||
)
|
)
|
||||||
elif shield.shield_type == ShieldType.prompt_guard.value:
|
elif shield.shield_type == ShieldType.prompt_guard:
|
||||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||||
subtype = shield.params.get("prompt_guard_type", "injection")
|
subtype = shield.params.get("prompt_guard_type", "injection")
|
||||||
if subtype == "injection":
|
if subtype == "injection":
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue