diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 2733dde73..4775da131 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -5,87 +5,43 @@ # the root directory of this source tree. from enum import Enum -from typing import Dict, List, Optional, Protocol, Union +from typing import Any, Dict, Protocol from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel, validator +from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig @json_schema_type -class BuiltinShield(Enum): - llama_guard = "llama_guard" - code_scanner_guard = "code_scanner_guard" - third_party_shield = "third_party_shield" - injection_shield = "injection_shield" - jailbreak_shield = "jailbreak_shield" - - -ShieldType = Union[BuiltinShield, str] +class ViolationLevel(Enum): + INFO = "info" + WARN = "warn" + ERROR = "error" @json_schema_type -class OnViolationAction(Enum): - IGNORE = 0 - WARN = 1 - RAISE = 2 +class SafetyViolation(BaseModel): + violation_level: ViolationLevel + # what message should you convey to the user + user_message: Optional[str] = None -@json_schema_type -class ShieldDefinition(BaseModel): - shield_type: ShieldType - description: Optional[str] = None - parameters: Optional[Dict[str, ToolParamDefinition]] = None - on_violation_action: OnViolationAction = OnViolationAction.RAISE - execution_config: Optional[RestAPIExecutionConfig] = None - - @validator("shield_type", pre=True) - @classmethod - def validate_field(cls, v): - if isinstance(v, str): - try: - return BuiltinShield(v) - except ValueError: - return v - return v - - -@json_schema_type -class ShieldResponse(BaseModel): - shield_type: ShieldType - # TODO(ashwin): clean this up - is_violation: bool - violation_type: Optional[str] = None - violation_return_message: Optional[str] = None - - @validator("shield_type", pre=True) - @classmethod - def validate_field(cls, v): - if isinstance(v, str): - try: - return BuiltinShield(v) - except ValueError: - return v - return v - - -@json_schema_type -class RunShieldRequest(BaseModel): - messages: List[Message] - shields: List[ShieldDefinition] + # additional metadata (including specific violation codes) more for + # debugging, telemetry + metadata: Dict[str, Any] = Field(default_factory=dict) @json_schema_type class RunShieldResponse(BaseModel): - responses: List[ShieldResponse] + violation: Optional[SafetyViolation] = None + + +ShieldType = str class Safety(Protocol): - @webmethod(route="/safety/run_shields") - async def run_shields( - self, - messages: List[Message], - shields: List[ShieldDefinition], + @webmethod(route="/safety/run_shield") + async def run_shield( + self, shield: ShieldType, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: ... diff --git a/llama_stack/providers/adapters/safety/__init__.py b/llama_stack/providers/adapters/safety/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/adapters/safety/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/adapters/safety/bedrock/__init__.py b/llama_stack/providers/adapters/safety/bedrock/__init__.py new file mode 100644 index 000000000..fd6ad5343 --- /dev/null +++ b/llama_stack/providers/adapters/safety/bedrock/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.distribution.datatypes import RemoteProviderConfig + +from .config import BedrockSafetyRequestProviderData # noqa: F403 + + +async def get_adapter_impl(config: RemoteProviderConfig, _deps): + from .bedrock import BedrockSafetyAdapter + + impl = BedrockSafetyAdapter(config.url) + await impl.initialize() + return impl diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py new file mode 100644 index 000000000..f746eaa24 --- /dev/null +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List + +from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.providers.utils import get_request_provider_data + +from .config import BedrockSafetyRequestProviderData + + +class BedrockSafetyAdapter(Safety): + def __init__(self, url: str) -> None: + self.url = url + pass + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def run_shield( + self, + shield: ShieldType, + messages: List[Message], + ) -> RunShieldResponse: + # clients will set api_keys by doing something like: + # + # client = llama_stack.LlamaStack() + # await client.safety.run_shield( + # shield_type="aws_guardrail_type", + # messages=[ ... ], + # x_llamastack_provider_data={ + # "aws_api_key": "..." + # } + # ) + # + # This information will arrive at the LlamaStack server via a HTTP Header. + # + # The server will then provide you a type-checked version of this provider data + # automagically by extracting it from the header and validating it with the + # BedrockSafetyRequestProviderData class you will need to register in the provider + # registry. + # + provider_data: BedrockSafetyRequestProviderData = get_request_provider_data() + # use `aws_api_key` to pass to the AWS servers in whichever form + + raise NotImplementedError() diff --git a/llama_stack/providers/adapters/safety/bedrock/config.py b/llama_stack/providers/adapters/safety/bedrock/config.py new file mode 100644 index 000000000..344048469 --- /dev/null +++ b/llama_stack/providers/adapters/safety/bedrock/config.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class BedrockSafetyRequestProviderData(BaseModel): + aws_api_key: str + # other AWS specific keys you may need diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index baf0ebb46..090064a32 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -23,13 +23,6 @@ from .shields import ( ) -def resolve_and_get_path(model_name: str) -> str: - model = resolve_model(model_name) - assert model is not None, f"Could not resolve model {model_name}" - model_dir = model_local_dir(model.descriptor()) - return model_dir - - class MetaReferenceSafetyImpl(Safety): def __init__(self, config: SafetyConfig) -> None: self.config = config @@ -50,16 +43,17 @@ class MetaReferenceSafetyImpl(Safety): model_dir = resolve_and_get_path(shield_cfg.model) _ = PromptGuardShield.instance(model_dir) - async def run_shields( + async def run_shield( self, + shield_type: ShieldType, messages: List[Message], - shields: List[ShieldDefinition], ) -> RunShieldResponse: - shields = [shield_config_to_shield(c, self.config) for c in shields] + assert shield_type in [ + "llama_guard", + "prompt_guard", + ], f"Unknown shield {shield_type}" - responses = await asyncio.gather(*[shield.run(messages) for shield in shields]) - - return RunShieldResponse(responses=responses) + raise NotImplementedError() def shield_type_equals(a: ShieldType, b: ShieldType): diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 6e9583066..bbb1dd5a9 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -6,7 +6,12 @@ from typing import List -from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_stack.distribution.datatypes import ( + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) def available_providers() -> List[ProviderSpec]: @@ -23,4 +28,15 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.impls.meta_reference.safety", config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig", ), + remote_provider_spec( + api=Api.safety, + adapter=AdapterSpec( + adapter_id="bedrock", + pip_packages=[ + "aws-sdk", + ], + module="llama_stack.providers.adapters.safety.bedrock", + header_extractor="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyRequestProviderData", + ), + ), ]