diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 5cc9ce242..bfbd5616e 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -37,8 +37,8 @@ class AgentTool(Enum):
 
 
 class ToolDefinitionCommon(BaseModel):
-    input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
-    output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+    input_shields: Optional[List[str]] = Field(default_factory=list)
+    output_shields: Optional[List[str]] = Field(default_factory=list)
 
 
 class SearchEngineType(Enum):
@@ -266,8 +266,8 @@ class Session(BaseModel):
 
 class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()
 
-    input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
-    output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
+    input_shields: Optional[List[str]] = Field(default_factory=list)
+    output_shields: Optional[List[str]] = Field(default_factory=list)
     tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
diff --git a/llama_stack/apis/safety/client.py b/llama_stack/apis/safety/client.py
index 27ddc8dd5..2e2236c8f 100644
--- a/llama_stack/apis/safety/client.py
+++ b/llama_stack/apis/safety/client.py
@@ -13,11 +13,11 @@
 import fire
 import httpx
 
 from llama_models.llama3.api.datatypes import UserMessage
-
-from llama_stack.distribution.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint
 
+from llama_stack.distribution.datatypes import RemoteProviderConfig
+
 from .safety import *  # noqa: F403
 
@@ -69,11 +69,7 @@ async def run_main(host: str, port: int):
     response = await client.run_shields(
         RunShieldRequest(
             messages=[message],
-            shields=[
-                ShieldDefinition(
-                    shield_type=BuiltinShield.llama_guard,
-                )
-            ],
+            shields=["llama_guard"],
         )
     )
     print(response)
diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py
index 4775da131..cb8eb3c4a 100644
--- a/llama_stack/apis/safety/safety.py
+++ b/llama_stack/apis/safety/safety.py
@@ -37,11 +37,8 @@ class RunShieldResponse(BaseModel):
     violation: Optional[SafetyViolation] = None
 
 
-ShieldType = str
-
-
 class Safety(Protocol):
     @webmethod(route="/safety/run_shield")
     async def run_shield(
-        self, shield: ShieldType, messages: List[Message], params: Dict[str, Any] = None
+        self, shield: str, messages: List[Message], params: Dict[str, Any] = None
    ) -> RunShieldResponse: ...
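For orientation, here is a minimal caller sketch against the revised `Safety` protocol (illustrative only, not part of the patch): shields are now addressed by plain strings, and the result is a `RunShieldResponse` whose optional `violation` carries the details. The `safety_api` object below is an assumed, already-wired provider instance.

```python
# Hedged sketch, not from this patch: shields are identified by plain strings.
from llama_models.llama3.api.datatypes import UserMessage


async def moderate(safety_api) -> None:
    # was: shields=[ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
    response = await safety_api.run_shield(
        "llama_guard",
        [UserMessage(content="Test message")],
    )
    if response.violation is not None:
        print(response.violation.user_message)
```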
diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py
index 8bbf6b466..04e56056d 100644
--- a/llama_stack/providers/impls/meta_reference/agents/safety.py
+++ b/llama_stack/providers/impls/meta_reference/agents/safety.py
@@ -4,51 +4,46 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
+
 from typing import List
 
-from llama_models.llama3.api.datatypes import Message, Role, UserMessage
+from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint
 
-from llama_stack.apis.safety import (
-    OnViolationAction,
-    Safety,
-    ShieldDefinition,
-    ShieldResponse,
-)
+from llama_stack.apis.safety import *  # noqa: F403
 
 
 class SafetyException(Exception):  # noqa: N818
-    def __init__(self, response: ShieldResponse):
-        self.response = response
-        super().__init__(response.violation_return_message)
+    def __init__(self, violation: SafetyViolation):
+        self.violation = violation
+        super().__init__(violation.user_message)
 
 
 class ShieldRunnerMixin:
     def __init__(
         self,
         safety_api: Safety,
-        input_shields: List[ShieldDefinition] = None,
-        output_shields: List[ShieldDefinition] = None,
+        input_shields: List[str] = None,
+        output_shields: List[str] = None,
     ):
         self.safety_api = safety_api
         self.input_shields = input_shields
         self.output_shields = output_shields
 
-    async def run_shields(
-        self, messages: List[Message], shields: List[ShieldDefinition]
-    ) -> List[ShieldResponse]:
-        messages = messages.copy()
-        # some shields like llama-guard require the first message to be a user message
-        # since this might be a tool call, first role might not be user
-        if len(messages) > 0 and messages[0].role != Role.user.value:
-            messages[0] = UserMessage(content=messages[0].content)
-
-        results = await self.safety_api.run_shields(
-            messages=messages,
-            shields=shields,
+    async def run_shields(self, messages: List[Message], shields: List[str]) -> None:
+        responses = await asyncio.gather(
+            *[
+                self.safety_api.run_shield(
+                    shield_type=shield_type,
+                    messages=messages,
+                )
+                for shield_type in shields
+            ]
         )
-        for shield, r in zip(shields, results):
-            if r.is_violation:
-                if shield.on_violation_action == OnViolationAction.RAISE:
-                    raise SafetyException(r)
-                elif shield.on_violation_action == OnViolationAction.WARN:
+
+        for shield, r in zip(shields, responses):
+            if r.violation:
+                # shields are plain strings now, so key off the violation's level
+                if r.violation.violation_level == ViolationLevel.ERROR:
+                    raise SafetyException(r.violation)
+                elif r.violation.violation_level == ViolationLevel.WARN:
@@ -56,5 +51,3 @@ class ShieldRunnerMixin:
-                    f"[Warn]{shield.__class__.__name__} raised a warning",
+                    f"[Warn]{shield} raised a warning",
                     color="red",
                 )
-
-        return results
diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
index 43d159e69..6cbd6000c 100644
--- a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
+++ b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
@@ -223,7 +223,7 @@ async def test_chat_agent_create_and_execute_turn(chat_agent):
 @pytest.mark.asyncio
 async def test_run_shields_wrapper(chat_agent):
     messages = [UserMessage(content="Test message")]
-    shields = [ShieldDefinition(shield_type="test_shield")]
+    shields = ["test_shield"]
 
     responses = [
         chunk
diff --git a/llama_stack/providers/impls/meta_reference/agents/tools/safety.py b/llama_stack/providers/impls/meta_reference/agents/tools/safety.py
index d36dc3490..58bfbfeb4 100644
--- a/llama_stack/providers/impls/meta_reference/agents/tools/safety.py
+++ b/llama_stack/providers/impls/meta_reference/agents/tools/safety.py
@@ -7,7 +7,7 @@
 from typing import List
 
 from llama_stack.apis.inference import Message
-from llama_stack.apis.safety import Safety, ShieldDefinition
+from llama_stack.apis.safety import *  # noqa: F403
 
 from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
 
@@ -21,8 +21,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
         self,
         tool: BaseTool,
         safety_api: Safety,
-        input_shields: List[ShieldDefinition] = None,
-        output_shields: List[ShieldDefinition] = None,
+        input_shields: List[str] = None,
+        output_shields: List[str] = None,
     ):
         self._tool = tool
         ShieldRunnerMixin.__init__(
@@ -30,7 +30,6 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
         )
 
     def get_name(self) -> str:
-        # return the name of the wrapped tool
         return self._tool.get_name()
 
     async def run(self, messages: List[Message]) -> List[Message]:
@@ -47,8 +46,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
 def with_safety(
     tool: BaseTool,
     safety_api: Safety,
-    input_shields: List[ShieldDefinition] = None,
-    output_shields: List[ShieldDefinition] = None,
+    input_shields: List[str] = None,
+    output_shields: List[str] = None,
 ) -> SafeTool:
     return SafeTool(
         tool,
diff --git a/llama_stack/providers/impls/meta_reference/safety/config.py b/llama_stack/providers/impls/meta_reference/safety/config.py
index 4d68d2e48..98751cf3e 100644
--- a/llama_stack/providers/impls/meta_reference/safety/config.py
+++ b/llama_stack/providers/impls/meta_reference/safety/config.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from enum import Enum
 from typing import List, Optional
 
 from llama_models.sku_list import CoreModelId, safety_models
@@ -11,6 +12,13 @@ from llama_models.sku_list import CoreModelId, safety_models
 from pydantic import BaseModel, validator
 
 
+class MetaReferenceShieldType(Enum):
+    llama_guard = "llama_guard"
+    code_scanner_guard = "code_scanner_guard"
+    injection_shield = "injection_shield"
+    jailbreak_shield = "jailbreak_shield"
+
+
 class LlamaGuardShieldConfig(BaseModel):
     model: str = "Llama-Guard-3-8B"
     excluded_categories: List[str] = []
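Before the provider change itself, a quick illustrative sketch (not in the patch) of how the new `MetaReferenceShieldType` enum round-trips the shield strings the meta-reference provider accepts; `run_shield` below performs exactly this validation.

```python
# Hedged sketch: string <-> enum round-trip that gates run_shield inputs.
from llama_stack.providers.impls.meta_reference.safety.config import (
    MetaReferenceShieldType,
)

shield_type = "llama_guard"
available_shields = [v.value for v in MetaReferenceShieldType]
assert shield_type in available_shields, f"Unknown shield {shield_type}"

typ = MetaReferenceShieldType(shield_type)  # raises ValueError for unknown strings
assert typ is MetaReferenceShieldType.llama_guard
```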
diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py
index 090064a32..6eccf47a5 100644
--- a/llama_stack/providers/impls/meta_reference/safety/safety.py
+++ b/llama_stack/providers/impls/meta_reference/safety/safety.py
@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import asyncio
-
 from llama_models.sku_list import resolve_model
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.apis.safety import *  # noqa
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+
+from .config import MetaReferenceShieldType, SafetyConfig
 
-from .config import SafetyConfig
 from .shields import (
     CodeScannerShield,
     InjectionShield,
@@ -19,10 +19,16 @@ from .shields import (
     LlamaGuardShield,
     PromptGuardShield,
     ShieldBase,
-    ThirdPartyShield,
 )
 
 
+def resolve_and_get_path(model_name: str) -> str:
+    model = resolve_model(model_name)
+    assert model is not None, f"Could not resolve model {model_name}"
+    model_dir = model_local_dir(model.descriptor())
+    return model_dir
+
+
 class MetaReferenceSafetyImpl(Safety):
     def __init__(self, config: SafetyConfig) -> None:
         self.config = config
@@ -45,45 +51,56 @@ class MetaReferenceSafetyImpl(Safety):
 
     async def run_shield(
         self,
-        shield_type: ShieldType,
+        shield_type: str,
         messages: List[Message],
+        params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        assert shield_type in [
-            "llama_guard",
-            "prompt_guard",
-        ], f"Unknown shield {shield_type}"
+        available_shields = [v.value for v in MetaReferenceShieldType]
+        assert shield_type in available_shields, f"Unknown shield {shield_type}"
 
-        raise NotImplementedError()
+        shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
 
+        messages = messages.copy()
+        # some shields like llama-guard require the first message to be a user message
+        # since this might be a tool call, first role might not be user
+        if len(messages) > 0 and messages[0].role != Role.user.value:
+            messages[0] = UserMessage(content=messages[0].content)
 
-def shield_type_equals(a: ShieldType, b: ShieldType):
-    return a == b or a == b.value
+        # TODO: we can refactor ShieldBase, etc. to be inline with the API types
+        res = await shield.run(messages)
+        violation = None
+        if res.is_violation:
+            violation = SafetyViolation(
+                violation_level=ViolationLevel.ERROR,
+                user_message=res.violation_return_message,
+                metadata={
+                    "violation_type": res.violation_type,
+                },
+            )
+
+        return RunShieldResponse(violation=violation)
 
-
-def shield_config_to_shield(
-    sc: ShieldDefinition, safety_config: SafetyConfig
-) -> ShieldBase:
-    if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
-        assert (
-            safety_config.llama_guard_shield is not None
-        ), "Cannot use LlamaGuardShield since not present in config"
-        model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
-        return LlamaGuardShield.instance(model_dir=model_dir)
-    elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
-        assert (
-            safety_config.prompt_guard_shield is not None
-        ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
-        model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
-        return JailbreakShield.instance(model_dir)
-    elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
-        assert (
-            safety_config.prompt_guard_shield is not None
-        ), "Cannot use PromptGuardShield since not present in config"
-        model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
-        return InjectionShield.instance(model_dir)
-    elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
-        return CodeScannerShield.instance()
-    elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
-        return ThirdPartyShield.instance()
-    else:
-        raise ValueError(f"Unknown shield type: {sc.shield_type}")
+    def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
+        cfg = self.config
+        if typ == MetaReferenceShieldType.llama_guard:
+            assert (
+                cfg.llama_guard_shield is not None
+            ), "Cannot use LlamaGuardShield since not present in config"
+            model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
+            return LlamaGuardShield.instance(model_dir=model_dir)
+        elif typ == MetaReferenceShieldType.jailbreak_shield:
+            assert (
+                cfg.prompt_guard_shield is not None
+            ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
+            model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
+            return JailbreakShield.instance(model_dir)
+        elif typ == MetaReferenceShieldType.injection_shield:
+            assert (
+                cfg.prompt_guard_shield is not None
+            ), "Cannot use PromptGuardShield since not present in config"
+            model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
+            return InjectionShield.instance(model_dir)
+        elif typ == MetaReferenceShieldType.code_scanner_guard:
+            return CodeScannerShield.instance()
+        else:
+            raise ValueError(f"Unknown shield type: {typ}")
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/base.py b/llama_stack/providers/impls/meta_reference/safety/shields/base.py
index 64e64e2fd..86124b1e5 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/base.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/base.py
@@ -8,11 +8,26 @@ from abc import ABC, abstractmethod
 from typing import List
 
 from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
+from pydantic import BaseModel
 
 from llama_stack.apis.safety import *  # noqa: F403
 
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
+# TODO: clean this up; just remove this type completely
+class ShieldResponse(BaseModel):
+    is_violation: bool
+    violation_type: Optional[str] = None
+    violation_return_message: Optional[str] = None
+
+
+# TODO: this is a caller / agent concern
+class OnViolationAction(Enum):
+    IGNORE = 0
+    WARN = 1
+    RAISE = 2
+
+
 class ShieldBase(ABC):
     def __init__(
         self,
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/contrib/__init__.py b/llama_stack/providers/impls/meta_reference/safety/shields/contrib/__init__.py
deleted file mode 100644
index 756f351d8..000000000
--- a/llama_stack/providers/impls/meta_reference/safety/shields/contrib/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/contrib/third_party_shield.py b/llama_stack/providers/impls/meta_reference/safety/shields/contrib/third_party_shield.py
deleted file mode 100644
index cc652ae63..000000000
--- a/llama_stack/providers/impls/meta_reference/safety/shields/contrib/third_party_shield.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import List
-
-from llama_models.llama3.api.datatypes import Message
-
-from llama_stack.providers.impls.meta_reference.safety.shields.base import (
-    OnViolationAction,
-    ShieldBase,
-    ShieldResponse,
-)
-
-_INSTANCE = None
-
-
-class ThirdPartyShield(ShieldBase):
-    @staticmethod
-    def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
-        global _INSTANCE
-        if _INSTANCE is None:
-            _INSTANCE = ThirdPartyShield(on_violation_action)
-        return _INSTANCE
-
-    def __init__(
-        self,
-        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
-    ):
-        super().__init__(on_violation_action)
-
-    async def run(self, messages: List[Message]) -> ShieldResponse:
-        super.run()  # will raise NotImplementedError
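Taken together, a hedged end-to-end sketch of the post-change flow (again illustrative: `safety_api` stands in for a configured provider such as `MetaReferenceSafetyImpl`, and the import paths mirror the modules touched above):

```python
from typing import List

from llama_models.llama3.api.datatypes import Message

from llama_stack.apis.safety import *  # noqa: F403
from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException


async def guarded(safety_api: Safety, messages: List[Message]) -> None:
    # run a single named shield; the provider maps its internal ShieldResponse
    # to a RunShieldResponse carrying an optional SafetyViolation
    response = await safety_api.run_shield("llama_guard", messages)
    if response.violation is not None:
        if response.violation.violation_level == ViolationLevel.ERROR:
            raise SafetyException(response.violation)
```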