Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Update the meta reference safety implementation to match new API
This commit is contained in:
parent 93e4ef3829
commit 51245a417b
11 changed files with 115 additions and 130 deletions
@@ -37,8 +37,8 @@ class AgentTool(Enum):


 class ToolDefinitionCommon(BaseModel):
-    input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
-    output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+    input_shields: Optional[List[str]] = Field(default_factory=list)
+    output_shields: Optional[List[str]] = Field(default_factory=list)


 class SearchEngineType(Enum):
@@ -266,8 +266,8 @@ class Session(BaseModel):
 class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()

-    input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
-    output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+    input_shields: Optional[List[str]] = Field(default_factory=list)
+    output_shields: Optional[List[str]] = Field(default_factory=list)

     tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
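Taken together, the two hunks above switch agent and tool configuration to reference shields by plain name instead of inline ShieldDefinition objects. A minimal standalone sketch of the new field shape; AgentConfigSketch is a hypothetical local model, not the real AgentConfigCommon:

# Hypothetical stand-in showing the "shields as plain strings" field shape.
from typing import List, Optional

from pydantic import BaseModel, Field


class AgentConfigSketch(BaseModel):
    # Shields are now identified by name; the provider resolves the implementation.
    input_shields: Optional[List[str]] = Field(default_factory=list)
    output_shields: Optional[List[str]] = Field(default_factory=list)


config = AgentConfigSketch(input_shields=["llama_guard"], output_shields=["llama_guard"])
print(config.input_shields, config.output_shields)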
@@ -13,11 +13,11 @@ import fire
 import httpx

 from llama_models.llama3.api.datatypes import UserMessage

-from llama_stack.distribution.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint


+from llama_stack.distribution.datatypes import RemoteProviderConfig
+
 from .safety import *  # noqa: F403
@@ -69,11 +69,7 @@ async def run_main(host: str, port: int):
     response = await client.run_shields(
         RunShieldRequest(
             messages=[message],
-            shields=[
-                ShieldDefinition(
-                    shield_type=BuiltinShield.llama_guard,
-                )
-            ],
+            shields=["llama_guard"],
         )
     )
     print(response)
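With this change the client sends bare shield names. A small standalone sketch of that payload shape; RunShieldRequestSketch is a hypothetical stand-in, and real calls pass typed Message objects rather than dicts:

# Hypothetical stand-in mirroring the simplified request built above.
from typing import List

from pydantic import BaseModel


class RunShieldRequestSketch(BaseModel):
    messages: List[dict]  # real code passes typed Message objects
    shields: List[str]    # previously a list of ShieldDefinition objects


req = RunShieldRequestSketch(
    messages=[{"role": "user", "content": "hello"}],
    shields=["llama_guard"],
)
print(req)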
@@ -37,11 +37,8 @@ class RunShieldResponse(BaseModel):
     violation: Optional[SafetyViolation] = None


-ShieldType = str
-
-
 class Safety(Protocol):
     @webmethod(route="/safety/run_shield")
     async def run_shield(
-        self, shield: ShieldType, messages: List[Message], params: Dict[str, Any] = None
+        self, shield: str, messages: List[Message], params: Dict[str, Any] = None
    ) -> RunShieldResponse: ...
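The protocol method now takes the shield name as a plain string and returns a RunShieldResponse whose violation field is None when nothing was flagged. A runnable toy illustration under those assumptions; the models and the allow-everything provider below are simplified stand-ins, not the real API types:

# Toy illustration of the updated run_shield signature (shield is a plain str).
import asyncio
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class SafetyViolationSketch(BaseModel):
    violation_level: str
    user_message: Optional[str] = None
    metadata: Dict[str, Any] = {}


class RunShieldResponseSketch(BaseModel):
    violation: Optional[SafetyViolationSketch] = None


class NoopSafety:
    async def run_shield(
        self, shield: str, messages: List[Any], params: Dict[str, Any] = None
    ) -> RunShieldResponseSketch:
        # A real provider would dispatch on the shield name; this one never flags anything.
        return RunShieldResponseSketch(violation=None)


async def main() -> None:
    resp = await NoopSafety().run_shield("llama_guard", messages=[])
    print("violation:", resp.violation)


asyncio.run(main())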
@@ -4,51 +4,46 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
+
 from typing import List

-from llama_models.llama3.api.datatypes import Message, Role, UserMessage
+from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint

-from llama_stack.apis.safety import (
-    OnViolationAction,
-    Safety,
-    ShieldDefinition,
-    ShieldResponse,
-)
+from llama_stack.apis.safety import *  # noqa: F403


 class SafetyException(Exception):  # noqa: N818
-    def __init__(self, response: ShieldResponse):
-        self.response = response
-        super().__init__(response.violation_return_message)
+    def __init__(self, violation: SafetyViolation):
+        self.violation = violation
+        super().__init__(violation.user_message)


 class ShieldRunnerMixin:
     def __init__(
         self,
         safety_api: Safety,
-        input_shields: List[ShieldDefinition] = None,
-        output_shields: List[ShieldDefinition] = None,
+        input_shields: List[str] = None,
+        output_shields: List[str] = None,
     ):
         self.safety_api = safety_api
         self.input_shields = input_shields
         self.output_shields = output_shields

-    async def run_shields(
-        self, messages: List[Message], shields: List[ShieldDefinition]
-    ) -> List[ShieldResponse]:
-        messages = messages.copy()
-        # some shields like llama-guard require the first message to be a user message
-        # since this might be a tool call, first role might not be user
-        if len(messages) > 0 and messages[0].role != Role.user.value:
-            messages[0] = UserMessage(content=messages[0].content)
-
-        results = await self.safety_api.run_shields(
-            messages=messages,
-            shields=shields,
+    async def run_shields(self, messages: List[Message], shields: List[str]) -> None:
+        responses = await asyncio.gather(
+            *[
+                self.safety_api.run_shield(
+                    shield_type=shield_type,
+                    messages=messages,
+                )
+                for shield_type in shields
+            ]
         )
-        for shield, r in zip(shields, results):
-            if r.is_violation:
+
+        for shield, r in zip(shields, responses):
+            if r.violation:
                 if shield.on_violation_action == OnViolationAction.RAISE:
                     raise SafetyException(r)
                 elif shield.on_violation_action == OnViolationAction.WARN:
@@ -56,5 +51,3 @@ class ShieldRunnerMixin:
                     f"[Warn]{shield.__class__.__name__} raised a warning",
                     color="red",
                 )
-
-        return results
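The rewritten run_shields fans out one run_shield call per shield name and awaits them concurrently with asyncio.gather. A runnable toy that mirrors that structure with stand-in objects; FakeSafety is not the real Safety API:

# Toy mirror of the new run_shields flow: one run_shield call per shield name.
import asyncio
from typing import List


class FakeResponse:
    def __init__(self, violation=None):
        self.violation = violation


class FakeSafety:
    async def run_shield(self, shield_type: str, messages: List[str]) -> FakeResponse:
        await asyncio.sleep(0)  # pretend to do some inference work
        return FakeResponse(violation=None)


async def run_shields(safety_api: FakeSafety, messages: List[str], shields: List[str]) -> None:
    responses = await asyncio.gather(
        *[safety_api.run_shield(shield_type=s, messages=messages) for s in shields]
    )
    for shield, r in zip(shields, responses):
        if r.violation:
            raise RuntimeError(f"shield {shield} flagged a violation")


asyncio.run(run_shields(FakeSafety(), ["hello"], ["llama_guard", "injection_shield"]))
print("no violations")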
@@ -223,7 +223,7 @@ async def test_chat_agent_create_and_execute_turn(chat_agent):
 @pytest.mark.asyncio
 async def test_run_shields_wrapper(chat_agent):
     messages = [UserMessage(content="Test message")]
-    shields = [ShieldDefinition(shield_type="test_shield")]
+    shields = ["test_shield"]

     responses = [
         chunk
@@ -7,7 +7,7 @@
 from typing import List

 from llama_stack.apis.inference import Message
-from llama_stack.apis.safety import Safety, ShieldDefinition
+from llama_stack.apis.safety import *  # noqa: F403

 from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
@@ -21,8 +21,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
         self,
         tool: BaseTool,
         safety_api: Safety,
-        input_shields: List[ShieldDefinition] = None,
-        output_shields: List[ShieldDefinition] = None,
+        input_shields: List[str] = None,
+        output_shields: List[str] = None,
     ):
         self._tool = tool
         ShieldRunnerMixin.__init__(
@@ -30,7 +30,6 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
         )

     def get_name(self) -> str:
-        # return the name of the wrapped tool
         return self._tool.get_name()

     async def run(self, messages: List[Message]) -> List[Message]:
@@ -47,8 +46,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
 def with_safety(
     tool: BaseTool,
     safety_api: Safety,
-    input_shields: List[ShieldDefinition] = None,
-    output_shields: List[ShieldDefinition] = None,
+    input_shields: List[str] = None,
+    output_shields: List[str] = None,
 ) -> SafeTool:
     return SafeTool(
         tool,
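with_safety wraps a tool so its inputs and outputs can be screened by named shields. A compact stand-in sketch of the wrapper pattern; no actual safety calls are made here:

# Hypothetical stand-ins showing the wrap-a-tool-with-shields pattern.
import asyncio
from typing import List


class EchoTool:
    def get_name(self) -> str:
        return "echo"

    async def run(self, messages: List[str]) -> List[str]:
        return messages


class SafeToolSketch:
    def __init__(self, tool, input_shields: List[str] = None, output_shields: List[str] = None):
        self._tool = tool
        self.input_shields = input_shields or []
        self.output_shields = output_shields or []

    def get_name(self) -> str:
        return self._tool.get_name()  # delegate to the wrapped tool

    async def run(self, messages: List[str]) -> List[str]:
        # a real implementation would run input_shields here, call the tool,
        # then run output_shields on the result
        return await self._tool.run(messages)


safe = SafeToolSketch(EchoTool(), input_shields=["llama_guard"], output_shields=["llama_guard"])
print(safe.get_name(), asyncio.run(safe.run(["hi"])))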
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from enum import Enum
 from typing import List, Optional

 from llama_models.sku_list import CoreModelId, safety_models
@@ -11,6 +12,13 @@ from llama_models.sku_list import CoreModelId, safety_models
 from pydantic import BaseModel, validator


+class MetaReferenceShieldType(Enum):
+    llama_guard = "llama_guard"
+    code_scanner_guard = "code_scanner_guard"
+    injection_shield = "injection_shield"
+    jailbreak_shield = "jailbreak_shield"
+
+
 class LlamaGuardShieldConfig(BaseModel):
     model: str = "Llama-Guard-3-8B"
     excluded_categories: List[str] = []
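run_shield validates incoming names against this enum and converts them back to MetaReferenceShieldType values. A tiny standalone sketch of that round-trip using a local copy of the enum:

# Local copy of the enum from the hunk above, showing name validation and round-trip.
from enum import Enum


class MetaReferenceShieldType(Enum):
    llama_guard = "llama_guard"
    code_scanner_guard = "code_scanner_guard"
    injection_shield = "injection_shield"
    jailbreak_shield = "jailbreak_shield"


available_shields = [v.value for v in MetaReferenceShieldType]
print("llama_guard" in available_shields)       # True: the name is accepted
print(MetaReferenceShieldType("llama_guard"))   # round-trip from the plain string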
@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import asyncio
-
 from llama_models.sku_list import resolve_model

 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.apis.safety import *  # noqa
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403

-from .config import SafetyConfig
+from .config import MetaReferenceShieldType, SafetyConfig

 from .shields import (
     CodeScannerShield,
     InjectionShield,
@@ -19,10 +19,16 @@ from .shields import (
     LlamaGuardShield,
     PromptGuardShield,
     ShieldBase,
-    ThirdPartyShield,
 )


+def resolve_and_get_path(model_name: str) -> str:
+    model = resolve_model(model_name)
+    assert model is not None, f"Could not resolve model {model_name}"
+    model_dir = model_local_dir(model.descriptor())
+    return model_dir
+
+
 class MetaReferenceSafetyImpl(Safety):
     def __init__(self, config: SafetyConfig) -> None:
         self.config = config
@@ -45,45 +51,56 @@ class MetaReferenceSafetyImpl(Safety):

     async def run_shield(
         self,
-        shield_type: ShieldType,
+        shield_type: str,
         messages: List[Message],
+        params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        assert shield_type in [
-            "llama_guard",
-            "prompt_guard",
-        ], f"Unknown shield {shield_type}"
+        available_shields = [v.value for v in MetaReferenceShieldType]
+        assert shield_type in available_shields, f"Unknown shield {shield_type}"

-        raise NotImplementedError()
+        shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))

+        messages = messages.copy()
+        # some shields like llama-guard require the first message to be a user message
+        # since this might be a tool call, first role might not be user
+        if len(messages) > 0 and messages[0].role != Role.user.value:
+            messages[0] = UserMessage(content=messages[0].content)

-def shield_type_equals(a: ShieldType, b: ShieldType):
-    return a == b or a == b.value
+        # TODO: we can refactor ShieldBase, etc. to be inline with the API types
+        res = await shield.run(messages)
+        violation = None
+        if res.is_violation:
+            violation = SafetyViolation(
+                violation_level=ViolationLevel.ERROR,
+                user_message=res.violation_return_message,
+                metadata={
+                    "violation_type": res.violation_type,
+                },
+            )

+        return RunShieldResponse(violation=violation)

-def shield_config_to_shield(
-    sc: ShieldDefinition, safety_config: SafetyConfig
-) -> ShieldBase:
-    if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
-        assert (
-            safety_config.llama_guard_shield is not None
-        ), "Cannot use LlamaGuardShield since not present in config"
-        model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
-        return LlamaGuardShield.instance(model_dir=model_dir)
-    elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
-        assert (
-            safety_config.prompt_guard_shield is not None
-        ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
-        model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
-        return JailbreakShield.instance(model_dir)
-    elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
-        assert (
-            safety_config.prompt_guard_shield is not None
-        ), "Cannot use PromptGuardShield since not present in config"
-        model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
-        return InjectionShield.instance(model_dir)
-    elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
-        return CodeScannerShield.instance()
-    elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
-        return ThirdPartyShield.instance()
-    else:
-        raise ValueError(f"Unknown shield type: {sc.shield_type}")
+    def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
+        cfg = self.config
+        if typ == MetaReferenceShieldType.llama_guard:
+            assert (
+                cfg.llama_guard_shield is not None
+            ), "Cannot use LlamaGuardShield since not present in config"
+            model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
+            return LlamaGuardShield.instance(model_dir=model_dir)
+        elif typ == MetaReferenceShieldType.jailbreak_shield:
+            assert (
+                cfg.prompt_guard_shield is not None
+            ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
+            model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
+            return JailbreakShield.instance(model_dir)
+        elif typ == MetaReferenceShieldType.injection_shield:
+            assert (
+                cfg.prompt_guard_shield is not None
+            ), "Cannot use PromptGuardShield since not present in config"
+            model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
+            return InjectionShield.instance(model_dir)
+        elif typ == MetaReferenceShieldType.code_scanner_guard:
+            return CodeScannerShield.instance()
+        else:
+            raise ValueError(f"Unknown shield type: {typ}")
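The core of the new implementation is converting a shield's internal ShieldResponse into the API-level SafetyViolation and RunShieldResponse. A runnable sketch of that conversion with simplified stand-in models; violation_level is a plain string here rather than the real ViolationLevel enum:

# Sketch of the conversion step: internal ShieldResponse -> API-level response.
from typing import Any, Dict, Optional

from pydantic import BaseModel


class ShieldResponse(BaseModel):
    is_violation: bool
    violation_type: Optional[str] = None
    violation_return_message: Optional[str] = None


class SafetyViolation(BaseModel):
    violation_level: str
    user_message: Optional[str] = None
    metadata: Dict[str, Any] = {}


class RunShieldResponse(BaseModel):
    violation: Optional[SafetyViolation] = None


def to_api_response(res: ShieldResponse) -> RunShieldResponse:
    violation = None
    if res.is_violation:
        violation = SafetyViolation(
            violation_level="error",
            user_message=res.violation_return_message,
            metadata={"violation_type": res.violation_type},
        )
    return RunShieldResponse(violation=violation)


print(to_api_response(ShieldResponse(is_violation=False)))
print(
    to_api_response(
        ShieldResponse(
            is_violation=True,
            violation_type="S1",
            violation_return_message="I can't answer that.",
        )
    )
)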
@@ -8,11 +8,26 @@ from abc import ABC, abstractmethod
 from typing import List

 from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
+from pydantic import BaseModel
 from llama_stack.apis.safety import *  # noqa: F403

 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"


+# TODO: clean this up; just remove this type completely
+class ShieldResponse(BaseModel):
+    is_violation: bool
+    violation_type: Optional[str] = None
+    violation_return_message: Optional[str] = None
+
+
+# TODO: this is a caller / agent concern
+class OnViolationAction(Enum):
+    IGNORE = 0
+    WARN = 1
+    RAISE = 2
+
+
 class ShieldBase(ABC):
     def __init__(
         self,
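Concrete shields still return the internal ShieldResponse kept in this file from an async run(). A toy blocklist shield illustrating that contract; it is purely illustrative and not one of the provider shields:

# Toy shield built on the ShieldResponse shape shown above.
import asyncio
from typing import List, Optional

from pydantic import BaseModel


class ShieldResponse(BaseModel):
    is_violation: bool
    violation_type: Optional[str] = None
    violation_return_message: Optional[str] = None


class BlocklistShield:
    """Flags any message containing a blocked word (illustrative only)."""

    def __init__(self, blocked: List[str]):
        self.blocked = [w.lower() for w in blocked]

    async def run(self, messages: List[str]) -> ShieldResponse:
        for msg in messages:
            if any(word in msg.lower() for word in self.blocked):
                return ShieldResponse(
                    is_violation=True,
                    violation_type="blocklist",
                    violation_return_message="I can't answer that. Can I help with something else?",
                )
        return ShieldResponse(is_violation=False)


print(asyncio.run(BlocklistShield(["secret"]).run(["tell me the secret"])))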
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
@@ -1,35 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import List
-
-from llama_models.llama3.api.datatypes import Message
-
-from llama_stack.providers.impls.meta_reference.safety.shields.base import (
-    OnViolationAction,
-    ShieldBase,
-    ShieldResponse,
-)
-
-_INSTANCE = None
-
-
-class ThirdPartyShield(ShieldBase):
-    @staticmethod
-    def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
-        global _INSTANCE
-        if _INSTANCE is None:
-            _INSTANCE = ThirdPartyShield(on_violation_action)
-        return _INSTANCE
-
-    def __init__(
-        self,
-        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
-    ):
-        super().__init__(on_violation_action)
-
-    async def run(self, messages: List[Message]) -> ShieldResponse:
-        super.run()  # will raise NotImplementedError