diff --git a/docs/my-website/docs/proxy/guardrails/lakera_ai.md b/docs/my-website/docs/proxy/guardrails/lakera_ai.md index 1febd39e5..1863cb8b9 100644 --- a/docs/my-website/docs/proxy/guardrails/lakera_ai.md +++ b/docs/my-website/docs/proxy/guardrails/lakera_ai.md @@ -4,8 +4,8 @@ import TabItem from '@theme/TabItem'; # Lakera AI - -## 1. Define Guardrails on your LiteLLM config.yaml +## Quick Start +### 1. Define Guardrails on your LiteLLM config.yaml Define your guardrails under the `guardrails` section ```yaml @@ -22,23 +22,29 @@ guardrails: mode: "during_call" api_key: os.environ/LAKERA_API_KEY api_base: os.environ/LAKERA_API_BASE + - guardrail_name: "lakera-pre-guard" + litellm_params: + guardrail: lakera # supported values: "aporia", "bedrock", "lakera" + mode: "pre_call" + api_key: os.environ/LAKERA_API_KEY + api_base: os.environ/LAKERA_API_BASE ``` -### Supported values for `mode` +#### Supported values for `mode` - `pre_call` Run **before** LLM call, on **input** - `post_call` Run **after** LLM call, on **input & output** - `during_call` Run **during** LLM call, on **input** Same as `pre_call` but runs in parallel as LLM call. Response not returned until guardrail check completes -## 2. Start LiteLLM Gateway +### 2. Start LiteLLM Gateway ```shell litellm --config config.yaml --detailed_debug ``` -## 3. Test request +### 3. Test request **[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys##request-format)** @@ -120,4 +126,30 @@ curl -i http://localhost:4000/v1/chat/completions \ +## Advanced +### Set category-based thresholds. 
+Lakera scores prompt attacks in 2 categories; set a threshold for each under `category_thresholds`: +- jailbreak +- prompt_injection + +```yaml +model_list: + - model_name: fake-openai-endpoint + litellm_params: + model: openai/fake + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + +guardrails: + - guardrail_name: "lakera-pre-guard" + litellm_params: + guardrail: lakera # supported values: "aporia", "bedrock", "lakera" + mode: "during_call" + api_key: os.environ/LAKERA_API_KEY + api_base: os.environ/LAKERA_API_BASE + category_thresholds: + prompt_injection: 0.1 + jailbreak: 0.1 + +``` \ No newline at end of file diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py index 8ee856da8..e4e440c34 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py @@ -25,7 +25,12 @@ from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata -from litellm.types.guardrails import GuardrailItem, Role, default_roles +from litellm.types.guardrails import ( + GuardrailItem, + LakeraCategoryThresholds, + Role, + default_roles, +) GUARDRAIL_NAME = "lakera_prompt_injection" @@ -36,16 +41,11 @@ INPUT_POSITIONING_MAP = { } -class LakeraCategories(TypedDict, total=False): - jailbreak: float - prompt_injection: float - - class lakeraAI_Moderation(CustomGuardrail): def __init__( self, moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel", - category_thresholds: Optional[LakeraCategories] = None, + category_thresholds: Optional[LakeraCategoryThresholds] = None, api_base: Optional[str] = None, api_key: Optional[str] = None, **kwargs, @@ -72,7 +72,7 @@ class lakeraAI_Moderation(CustomGuardrail): if self.category_thresholds is not None: if 
category_scores is not None: - typed_cat_scores = LakeraCategories(**category_scores) + typed_cat_scores = LakeraCategoryThresholds(**category_scores) if ( "jailbreak" in typed_cat_scores and "jailbreak" in self.category_thresholds @@ -219,6 +219,8 @@ class lakeraAI_Moderation(CustomGuardrail): text = "\n".join(data["input"]) _json_data = json.dumps({"input": text}) + verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data) + # https://platform.lakera.ai/account/api-keys """ @@ -288,7 +290,18 @@ class lakeraAI_Moderation(CustomGuardrail): "pass_through_endpoint", ], ) -> Optional[Union[Exception, str, Dict]]: - if self.moderation_check == "in_parallel": + from litellm.types.guardrails import GuardrailEventHooks + + if self.event_hook is None: + if self.moderation_check == "in_parallel": + return None + + if ( + self.should_run_guardrail( + data=data, event_type=GuardrailEventHooks.pre_call + ) + is not True + ): return None return await self._check( diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index 95267e6bb..ad99daf95 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -12,6 +12,7 @@ from litellm.types.guardrails import ( Guardrail, GuardrailItem, GuardrailItemSpec, + LakeraCategoryThresholds, LitellmParams, guardrailConfig, ) @@ -99,6 +100,15 @@ def init_guardrails_v2(all_guardrails: dict): api_base=litellm_params_data["api_base"], ) + if ( + "category_thresholds" in litellm_params_data + and litellm_params_data["category_thresholds"] + ): + lakera_category_thresholds = LakeraCategoryThresholds( + **litellm_params_data["category_thresholds"] + ) + litellm_params["category_thresholds"] = lakera_category_thresholds + if litellm_params["api_key"]: if litellm_params["api_key"].startswith("os.environ/"): litellm_params["api_key"] = litellm.get_secret( @@ -134,6 +144,7 @@ def init_guardrails_v2(all_guardrails: dict): 
api_key=litellm_params["api_key"], guardrail_name=guardrail["guardrail_name"], event_hook=litellm_params["mode"], + category_thresholds=litellm_params.get("category_thresholds"), ) litellm.callbacks.append(_lakera_callback) # type: ignore diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 8f19b7e04..57609d29b 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -12,4 +12,7 @@ guardrails: mode: "during_call" api_key: os.environ/LAKERA_API_KEY api_base: os.environ/LAKERA_API_BASE + category_thresholds: + prompt_injection: 0.1 + jailbreak: 0.1 \ No newline at end of file diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index cd9f76f17..66c2a535a 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -66,11 +66,17 @@ class GuardrailItem(BaseModel): # Define the TypedDicts -class LitellmParams(TypedDict): +class LakeraCategoryThresholds(TypedDict, total=False): + prompt_injection: float + jailbreak: float + + +class LitellmParams(TypedDict, total=False): guardrail: str mode: str api_key: str api_base: Optional[str] + category_thresholds: Optional[LakeraCategoryThresholds] class Guardrail(TypedDict):