support lakera ai category thresholds

This commit is contained in:
Ishaan Jaff 2024-08-20 17:19:24 -07:00
parent 30da63bd4f
commit 8d2c529e55
5 changed files with 80 additions and 15 deletions

View file

@ -4,8 +4,8 @@ import TabItem from '@theme/TabItem';
# Lakera AI # Lakera AI
## Quick Start
## 1. Define Guardrails on your LiteLLM config.yaml ### 1. Define Guardrails on your LiteLLM config.yaml
Define your guardrails under the `guardrails` section Define your guardrails under the `guardrails` section
```yaml ```yaml
@ -22,23 +22,29 @@ guardrails:
mode: "during_call" mode: "during_call"
api_key: os.environ/LAKERA_API_KEY api_key: os.environ/LAKERA_API_KEY
api_base: os.environ/LAKERA_API_BASE api_base: os.environ/LAKERA_API_BASE
- guardrail_name: "lakera-pre-guard"
litellm_params:
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
mode: "pre_call"
api_key: os.environ/LAKERA_API_KEY
api_base: os.environ/LAKERA_API_BASE
``` ```
### Supported values for `mode` #### Supported values for `mode`
- `pre_call` Run **before** LLM call, on **input** - `pre_call` Run **before** LLM call, on **input**
- `post_call` Run **after** LLM call, on **input & output** - `post_call` Run **after** LLM call, on **input & output**
- `during_call` Run **during** LLM call, on **input** Same as `pre_call` but runs in parallel as LLM call. Response not returned until guardrail check completes - `during_call` Run **during** LLM call, on **input** Same as `pre_call` but runs in parallel as LLM call. Response not returned until guardrail check completes
## 2. Start LiteLLM Gateway ### 2. Start LiteLLM Gateway
```shell ```shell
litellm --config config.yaml --detailed_debug litellm --config config.yaml --detailed_debug
``` ```
## 3. Test request ### 3. Test request
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys##request-format)** **[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys##request-format)**
@ -120,4 +126,30 @@ curl -i http://localhost:4000/v1/chat/completions \
</Tabs> </Tabs>
## Advanced
### Set category-based thresholds.
Lakera has 2 categories for prompt_injection attacks:
- jailbreak
- prompt_injection
```yaml
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
guardrails:
- guardrail_name: "lakera-pre-guard"
litellm_params:
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
mode: "during_call"
api_key: os.environ/LAKERA_API_KEY
api_base: os.environ/LAKERA_API_BASE
category_thresholds:
prompt_injection: 0.1
jailbreak: 0.1
```

View file

@ -25,7 +25,12 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailItem, Role, default_roles from litellm.types.guardrails import (
GuardrailItem,
LakeraCategoryThresholds,
Role,
default_roles,
)
GUARDRAIL_NAME = "lakera_prompt_injection" GUARDRAIL_NAME = "lakera_prompt_injection"
@ -36,16 +41,11 @@ INPUT_POSITIONING_MAP = {
} }
class LakeraCategories(TypedDict, total=False):
jailbreak: float
prompt_injection: float
class lakeraAI_Moderation(CustomGuardrail): class lakeraAI_Moderation(CustomGuardrail):
def __init__( def __init__(
self, self,
moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel", moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel",
category_thresholds: Optional[LakeraCategories] = None, category_thresholds: Optional[LakeraCategoryThresholds] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
**kwargs, **kwargs,
@ -72,7 +72,7 @@ class lakeraAI_Moderation(CustomGuardrail):
if self.category_thresholds is not None: if self.category_thresholds is not None:
if category_scores is not None: if category_scores is not None:
typed_cat_scores = LakeraCategories(**category_scores) typed_cat_scores = LakeraCategoryThresholds(**category_scores)
if ( if (
"jailbreak" in typed_cat_scores "jailbreak" in typed_cat_scores
and "jailbreak" in self.category_thresholds and "jailbreak" in self.category_thresholds
@ -219,6 +219,8 @@ class lakeraAI_Moderation(CustomGuardrail):
text = "\n".join(data["input"]) text = "\n".join(data["input"])
_json_data = json.dumps({"input": text}) _json_data = json.dumps({"input": text})
verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data)
# https://platform.lakera.ai/account/api-keys # https://platform.lakera.ai/account/api-keys
""" """
@ -288,7 +290,18 @@ class lakeraAI_Moderation(CustomGuardrail):
"pass_through_endpoint", "pass_through_endpoint",
], ],
) -> Optional[Union[Exception, str, Dict]]: ) -> Optional[Union[Exception, str, Dict]]:
if self.moderation_check == "in_parallel": from litellm.types.guardrails import GuardrailEventHooks
if self.event_hook is None:
if self.moderation_check == "in_parallel":
return None
if (
self.should_run_guardrail(
data=data, event_type=GuardrailEventHooks.pre_call
)
is not True
):
return None return None
return await self._check( return await self._check(

View file

@ -12,6 +12,7 @@ from litellm.types.guardrails import (
Guardrail, Guardrail,
GuardrailItem, GuardrailItem,
GuardrailItemSpec, GuardrailItemSpec,
LakeraCategoryThresholds,
LitellmParams, LitellmParams,
guardrailConfig, guardrailConfig,
) )
@ -99,6 +100,15 @@ def init_guardrails_v2(all_guardrails: dict):
api_base=litellm_params_data["api_base"], api_base=litellm_params_data["api_base"],
) )
if (
"category_thresholds" in litellm_params_data
and litellm_params_data["category_thresholds"]
):
lakera_category_thresholds = LakeraCategoryThresholds(
**litellm_params_data["category_thresholds"]
)
litellm_params["category_thresholds"] = lakera_category_thresholds
if litellm_params["api_key"]: if litellm_params["api_key"]:
if litellm_params["api_key"].startswith("os.environ/"): if litellm_params["api_key"].startswith("os.environ/"):
litellm_params["api_key"] = litellm.get_secret( litellm_params["api_key"] = litellm.get_secret(
@ -134,6 +144,7 @@ def init_guardrails_v2(all_guardrails: dict):
api_key=litellm_params["api_key"], api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"], guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"], event_hook=litellm_params["mode"],
category_thresholds=litellm_params.get("category_thresholds"),
) )
litellm.callbacks.append(_lakera_callback) # type: ignore litellm.callbacks.append(_lakera_callback) # type: ignore

View file

@ -12,4 +12,7 @@ guardrails:
mode: "during_call" mode: "during_call"
api_key: os.environ/LAKERA_API_KEY api_key: os.environ/LAKERA_API_KEY
api_base: os.environ/LAKERA_API_BASE api_base: os.environ/LAKERA_API_BASE
category_thresholds:
prompt_injection: 0.1
jailbreak: 0.1

View file

@ -66,11 +66,17 @@ class GuardrailItem(BaseModel):
# Define the TypedDicts # Define the TypedDicts
class LitellmParams(TypedDict): class LakeraCategoryThresholds(TypedDict, total=False):
prompt_injection: float
jailbreak: float
class LitellmParams(TypedDict, total=False):
guardrail: str guardrail: str
mode: str mode: str
api_key: str api_key: str
api_base: Optional[str] api_base: Optional[str]
category_thresholds: Optional[LakeraCategoryThresholds]
class Guardrail(TypedDict): class Guardrail(TypedDict):