safety API cleanup part 1

Sample adapter implementation for Bedrock implementation of Guardrails
This commit is contained in:
Ashwin Bharambe 2024-09-20 10:57:26 -07:00 committed by Xi Yan
parent 32beecb20d
commit 7e40eead4e
7 changed files with 130 additions and 78 deletions

View file

@ -5,87 +5,43 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Protocol, Union from typing import Any, Dict, Protocol
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, validator from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
@json_schema_type @json_schema_type
class BuiltinShield(Enum): class ViolationLevel(Enum):
llama_guard = "llama_guard" INFO = "info"
code_scanner_guard = "code_scanner_guard" WARN = "warn"
third_party_shield = "third_party_shield" ERROR = "error"
injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
ShieldType = Union[BuiltinShield, str]
@json_schema_type @json_schema_type
class OnViolationAction(Enum): class SafetyViolation(BaseModel):
IGNORE = 0 violation_level: ViolationLevel
WARN = 1
RAISE = 2
# what message should you convey to the user
user_message: Optional[str] = None
@json_schema_type # additional metadata (including specific violation codes) more for
class ShieldDefinition(BaseModel): # debugging, telemetry
shield_type: ShieldType metadata: Dict[str, Any] = Field(default_factory=dict)
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
on_violation_action: OnViolationAction = OnViolationAction.RAISE
execution_config: Optional[RestAPIExecutionConfig] = None
@validator("shield_type", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinShield(v)
except ValueError:
return v
return v
@json_schema_type
class ShieldResponse(BaseModel):
shield_type: ShieldType
# TODO(ashwin): clean this up
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None
@validator("shield_type", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinShield(v)
except ValueError:
return v
return v
@json_schema_type
class RunShieldRequest(BaseModel):
messages: List[Message]
shields: List[ShieldDefinition]
@json_schema_type @json_schema_type
class RunShieldResponse(BaseModel): class RunShieldResponse(BaseModel):
responses: List[ShieldResponse] violation: Optional[SafetyViolation] = None
ShieldType = str
class Safety(Protocol): class Safety(Protocol):
@webmethod(route="/safety/run_shields") @webmethod(route="/safety/run_shield")
async def run_shields( async def run_shield(
self, self, shield: ShieldType, messages: List[Message], params: Dict[str, Any] = None
messages: List[Message],
shields: List[ShieldDefinition],
) -> RunShieldResponse: ... ) -> RunShieldResponse: ...

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .config import BedrockSafetyRequestProviderData # noqa: F403
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
from .bedrock import BedrockSafetyAdapter
impl = BedrockSafetyAdapter(config.url)
await impl.initialize()
return impl

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils import get_request_provider_data
from .config import BedrockSafetyRequestProviderData
class BedrockSafetyAdapter(Safety):
def __init__(self, url: str) -> None:
self.url = url
pass
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def run_shield(
self,
shield: ShieldType,
messages: List[Message],
) -> RunShieldResponse:
# clients will set api_keys by doing something like:
#
# client = llama_stack.LlamaStack()
# await client.safety.run_shield(
# shield_type="aws_guardrail_type",
# messages=[ ... ],
# x_llamastack_provider_data={
# "aws_api_key": "..."
# }
# )
#
# This information will arrive at the LlamaStack server via a HTTP Header.
#
# The server will then provide you a type-checked version of this provider data
# automagically by extracting it from the header and validating it with the
# BedrockSafetyRequestProviderData class you will need to register in the provider
# registry.
#
provider_data: BedrockSafetyRequestProviderData = get_request_provider_data()
# use `aws_api_key` to pass to the AWS servers in whichever form
raise NotImplementedError()

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class BedrockSafetyRequestProviderData(BaseModel):
aws_api_key: str
# other AWS specific keys you may need

View file

@ -23,13 +23,6 @@ from .shields import (
) )
def resolve_and_get_path(model_name: str) -> str:
model = resolve_model(model_name)
assert model is not None, f"Could not resolve model {model_name}"
model_dir = model_local_dir(model.descriptor())
return model_dir
class MetaReferenceSafetyImpl(Safety): class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None: def __init__(self, config: SafetyConfig) -> None:
self.config = config self.config = config
@ -50,16 +43,17 @@ class MetaReferenceSafetyImpl(Safety):
model_dir = resolve_and_get_path(shield_cfg.model) model_dir = resolve_and_get_path(shield_cfg.model)
_ = PromptGuardShield.instance(model_dir) _ = PromptGuardShield.instance(model_dir)
async def run_shields( async def run_shield(
self, self,
shield_type: ShieldType,
messages: List[Message], messages: List[Message],
shields: List[ShieldDefinition],
) -> RunShieldResponse: ) -> RunShieldResponse:
shields = [shield_config_to_shield(c, self.config) for c in shields] assert shield_type in [
"llama_guard",
"prompt_guard",
], f"Unknown shield {shield_type}"
responses = await asyncio.gather(*[shield.run(messages) for shield in shields]) raise NotImplementedError()
return RunShieldResponse(responses=responses)
def shield_type_equals(a: ShieldType, b: ShieldType): def shield_type_equals(a: ShieldType, b: ShieldType):

View file

@ -6,7 +6,12 @@
from typing import List from typing import List
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec from llama_stack.distribution.datatypes import (
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
def available_providers() -> List[ProviderSpec]: def available_providers() -> List[ProviderSpec]:
@ -23,4 +28,15 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.impls.meta_reference.safety", module="llama_stack.providers.impls.meta_reference.safety",
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig", config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
), ),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_id="bedrock",
pip_packages=[
"aws-sdk",
],
module="llama_stack.providers.adapters.safety.bedrock",
header_extractor="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyRequestProviderData",
),
),
] ]