mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Update the meta reference safety implementation to match new API
This commit is contained in:
parent
7e40eead4e
commit
82ddd851c8
11 changed files with 115 additions and 130 deletions
|
@ -37,8 +37,8 @@ class AgentTool(Enum):
|
|||
|
||||
|
||||
class ToolDefinitionCommon(BaseModel):
|
||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
|
@ -266,8 +266,8 @@ class Session(BaseModel):
|
|||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
|
||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
|
||||
tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
|
|
|
@ -13,11 +13,11 @@ import fire
|
|||
import httpx
|
||||
|
||||
from llama_models.llama3.api.datatypes import UserMessage
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
from .safety import * # noqa: F403
|
||||
|
||||
|
||||
|
@ -69,11 +69,7 @@ async def run_main(host: str, port: int):
|
|||
response = await client.run_shields(
|
||||
RunShieldRequest(
|
||||
messages=[message],
|
||||
shields=[
|
||||
ShieldDefinition(
|
||||
shield_type=BuiltinShield.llama_guard,
|
||||
)
|
||||
],
|
||||
shields=["llama_guard"],
|
||||
)
|
||||
)
|
||||
print(response)
|
||||
|
|
|
@ -37,11 +37,8 @@ class RunShieldResponse(BaseModel):
|
|||
violation: Optional[SafetyViolation] = None
|
||||
|
||||
|
||||
ShieldType = str
|
||||
|
||||
|
||||
class Safety(Protocol):
|
||||
@webmethod(route="/safety/run_shield")
|
||||
async def run_shield(
|
||||
self, shield: ShieldType, messages: List[Message], params: Dict[str, Any] = None
|
||||
self, shield: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse: ...
|
||||
|
|
|
@ -4,51 +4,46 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, Role, UserMessage
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.safety import (
|
||||
OnViolationAction,
|
||||
Safety,
|
||||
ShieldDefinition,
|
||||
ShieldResponse,
|
||||
)
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
def __init__(self, response: ShieldResponse):
|
||||
self.response = response
|
||||
super().__init__(response.violation_return_message)
|
||||
def __init__(self, violation: SafetyViolation):
|
||||
self.violation = violation
|
||||
super().__init__(violation.user_message)
|
||||
|
||||
|
||||
class ShieldRunnerMixin:
|
||||
def __init__(
|
||||
self,
|
||||
safety_api: Safety,
|
||||
input_shields: List[ShieldDefinition] = None,
|
||||
output_shields: List[ShieldDefinition] = None,
|
||||
input_shields: List[str] = None,
|
||||
output_shields: List[str] = None,
|
||||
):
|
||||
self.safety_api = safety_api
|
||||
self.input_shields = input_shields
|
||||
self.output_shields = output_shields
|
||||
|
||||
async def run_shields(
|
||||
self, messages: List[Message], shields: List[ShieldDefinition]
|
||||
) -> List[ShieldResponse]:
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
# since this might be a tool call, first role might not be user
|
||||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
results = await self.safety_api.run_shields(
|
||||
async def run_shields(self, messages: List[Message], shields: List[str]) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
shield_type=shield_type,
|
||||
messages=messages,
|
||||
shields=shields,
|
||||
)
|
||||
for shield, r in zip(shields, results):
|
||||
if r.is_violation:
|
||||
for shield_type in shields
|
||||
]
|
||||
)
|
||||
|
||||
for shield, r in zip(shields, responses):
|
||||
if r.violation:
|
||||
if shield.on_violation_action == OnViolationAction.RAISE:
|
||||
raise SafetyException(r)
|
||||
elif shield.on_violation_action == OnViolationAction.WARN:
|
||||
|
@ -56,5 +51,3 @@ class ShieldRunnerMixin:
|
|||
f"[Warn]{shield.__class__.__name__} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
|
||||
return results
|
||||
|
|
|
@ -223,7 +223,7 @@ async def test_chat_agent_create_and_execute_turn(chat_agent):
|
|||
@pytest.mark.asyncio
|
||||
async def test_run_shields_wrapper(chat_agent):
|
||||
messages = [UserMessage(content="Test message")]
|
||||
shields = [ShieldDefinition(shield_type="test_shield")]
|
||||
shields = ["test_shield"]
|
||||
|
||||
responses = [
|
||||
chunk
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, ShieldDefinition
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
|
||||
|
||||
|
@ -21,8 +21,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
|
|||
self,
|
||||
tool: BaseTool,
|
||||
safety_api: Safety,
|
||||
input_shields: List[ShieldDefinition] = None,
|
||||
output_shields: List[ShieldDefinition] = None,
|
||||
input_shields: List[str] = None,
|
||||
output_shields: List[str] = None,
|
||||
):
|
||||
self._tool = tool
|
||||
ShieldRunnerMixin.__init__(
|
||||
|
@ -30,7 +30,6 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
def get_name(self) -> str:
|
||||
# return the name of the wrapped tool
|
||||
return self._tool.get_name()
|
||||
|
||||
async def run(self, messages: List[Message]) -> List[Message]:
|
||||
|
@ -47,8 +46,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
|
|||
def with_safety(
|
||||
tool: BaseTool,
|
||||
safety_api: Safety,
|
||||
input_shields: List[ShieldDefinition] = None,
|
||||
output_shields: List[ShieldDefinition] = None,
|
||||
input_shields: List[str] = None,
|
||||
output_shields: List[str] = None,
|
||||
) -> SafeTool:
|
||||
return SafeTool(
|
||||
tool,
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_models.sku_list import CoreModelId, safety_models
|
||||
|
@ -11,6 +12,13 @@ from llama_models.sku_list import CoreModelId, safety_models
|
|||
from pydantic import BaseModel, validator
|
||||
|
||||
|
||||
class MetaReferenceShieldType(Enum):
|
||||
llama_guard = "llama_guard"
|
||||
code_scanner_guard = "code_scanner_guard"
|
||||
injection_shield = "injection_shield"
|
||||
jailbreak_shield = "jailbreak_shield"
|
||||
|
||||
|
||||
class LlamaGuardShieldConfig(BaseModel):
|
||||
model: str = "Llama-Guard-3-8B"
|
||||
excluded_categories: List[str] = []
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.safety import * # noqa
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from .config import MetaReferenceShieldType, SafetyConfig
|
||||
|
||||
from .config import SafetyConfig
|
||||
from .shields import (
|
||||
CodeScannerShield,
|
||||
InjectionShield,
|
||||
|
@ -19,10 +19,16 @@ from .shields import (
|
|||
LlamaGuardShield,
|
||||
PromptGuardShield,
|
||||
ShieldBase,
|
||||
ThirdPartyShield,
|
||||
)
|
||||
|
||||
|
||||
def resolve_and_get_path(model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert model is not None, f"Could not resolve model {model_name}"
|
||||
model_dir = model_local_dir(model.descriptor())
|
||||
return model_dir
|
||||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety):
|
||||
def __init__(self, config: SafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
@ -45,45 +51,56 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: ShieldType,
|
||||
shield_type: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
assert shield_type in [
|
||||
"llama_guard",
|
||||
"prompt_guard",
|
||||
], f"Unknown shield {shield_type}"
|
||||
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||
assert shield_type in available_shields, f"Unknown shield {shield_type}"
|
||||
|
||||
raise NotImplementedError()
|
||||
shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
|
||||
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
# since this might be a tool call, first role might not be user
|
||||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
def shield_type_equals(a: ShieldType, b: ShieldType):
|
||||
return a == b or a == b.value
|
||||
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
|
||||
res = await shield.run(messages)
|
||||
violation = None
|
||||
if res.is_violation:
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
user_message=res.violation_return_message,
|
||||
metadata={
|
||||
"violation_type": res.violation_type,
|
||||
},
|
||||
)
|
||||
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def shield_config_to_shield(
|
||||
sc: ShieldDefinition, safety_config: SafetyConfig
|
||||
) -> ShieldBase:
|
||||
if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
|
||||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||
cfg = self.config
|
||||
if typ == MetaReferenceShieldType.llama_guard:
|
||||
assert (
|
||||
safety_config.llama_guard_shield is not None
|
||||
cfg.llama_guard_shield is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
|
||||
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
|
||||
return LlamaGuardShield.instance(model_dir=model_dir)
|
||||
elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
|
||||
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
||||
assert (
|
||||
safety_config.prompt_guard_shield is not None
|
||||
cfg.prompt_guard_shield is not None
|
||||
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
|
||||
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
|
||||
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
|
||||
return JailbreakShield.instance(model_dir)
|
||||
elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
|
||||
elif typ == MetaReferenceShieldType.injection_shield:
|
||||
assert (
|
||||
safety_config.prompt_guard_shield is not None
|
||||
cfg.prompt_guard_shield is not None
|
||||
), "Cannot use PromptGuardShield since not present in config"
|
||||
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
|
||||
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
|
||||
elif typ == MetaReferenceShieldType.code_scanner_guard:
|
||||
return CodeScannerShield.instance()
|
||||
elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
|
||||
return ThirdPartyShield.instance()
|
||||
else:
|
||||
raise ValueError(f"Unknown shield type: {sc.shield_type}")
|
||||
raise ValueError(f"Unknown shield type: {typ}")
|
||||
|
|
|
@ -8,11 +8,26 @@ from abc import ABC, abstractmethod
|
|||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
||||
from pydantic import BaseModel
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
|
||||
# TODO: clean this up; just remove this type completely
|
||||
class ShieldResponse(BaseModel):
|
||||
is_violation: bool
|
||||
violation_type: Optional[str] = None
|
||||
violation_return_message: Optional[str] = None
|
||||
|
||||
|
||||
# TODO: this is a caller / agent concern
|
||||
class OnViolationAction(Enum):
|
||||
IGNORE = 0
|
||||
WARN = 1
|
||||
RAISE = 2
|
||||
|
||||
|
||||
class ShieldBase(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,35 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
ShieldBase,
|
||||
ShieldResponse,
|
||||
)
|
||||
|
||||
_INSTANCE = None
|
||||
|
||||
|
||||
class ThirdPartyShield(ShieldBase):
|
||||
@staticmethod
|
||||
def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
|
||||
global _INSTANCE
|
||||
if _INSTANCE is None:
|
||||
_INSTANCE = ThirdPartyShield(on_violation_action)
|
||||
return _INSTANCE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
super.run() # will raise NotImplementedError
|
Loading…
Add table
Add a link
Reference in a new issue