mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
support lakera ai category thresholds
This commit is contained in:
parent
30da63bd4f
commit
8d2c529e55
5 changed files with 80 additions and 15 deletions
|
@ -4,8 +4,8 @@ import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# Lakera AI
|
# Lakera AI
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
## 1. Define Guardrails on your LiteLLM config.yaml
|
### 1. Define Guardrails on your LiteLLM config.yaml
|
||||||
|
|
||||||
Define your guardrails under the `guardrails` section
|
Define your guardrails under the `guardrails` section
|
||||||
```yaml
|
```yaml
|
||||||
|
@ -22,23 +22,29 @@ guardrails:
|
||||||
mode: "during_call"
|
mode: "during_call"
|
||||||
api_key: os.environ/LAKERA_API_KEY
|
api_key: os.environ/LAKERA_API_KEY
|
||||||
api_base: os.environ/LAKERA_API_BASE
|
api_base: os.environ/LAKERA_API_BASE
|
||||||
|
- guardrail_name: "lakera-pre-guard"
|
||||||
|
litellm_params:
|
||||||
|
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
|
||||||
|
mode: "pre_call"
|
||||||
|
api_key: os.environ/LAKERA_API_KEY
|
||||||
|
api_base: os.environ/LAKERA_API_BASE
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Supported values for `mode`
|
#### Supported values for `mode`
|
||||||
|
|
||||||
- `pre_call` Run **before** LLM call, on **input**
|
- `pre_call` Run **before** LLM call, on **input**
|
||||||
- `post_call` Run **after** LLM call, on **input & output**
|
- `post_call` Run **after** LLM call, on **input & output**
|
||||||
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call`, but runs in parallel with the LLM call. The response is not returned until the guardrail check completes.
|
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call`, but runs in parallel with the LLM call. The response is not returned until the guardrail check completes.
|
||||||
|
|
||||||
## 2. Start LiteLLM Gateway
|
### 2. Start LiteLLM Gateway
|
||||||
|
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
litellm --config config.yaml --detailed_debug
|
litellm --config config.yaml --detailed_debug
|
||||||
```
|
```
|
||||||
|
|
||||||
## 3. Test request
|
### 3. Test request
|
||||||
|
|
||||||
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
|
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
|
||||||
|
|
||||||
|
@ -120,4 +126,30 @@ curl -i http://localhost:4000/v1/chat/completions \
|
||||||
|
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
## Advanced
|
||||||
|
### Set category-based thresholds
|
||||||
|
|
||||||
|
Lakera has 2 categories for prompt_injection attacks:
|
||||||
|
- jailbreak
|
||||||
|
- prompt_injection
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: fake-openai-endpoint
|
||||||
|
litellm_params:
|
||||||
|
model: openai/fake
|
||||||
|
api_key: fake-key
|
||||||
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
|
|
||||||
|
guardrails:
|
||||||
|
- guardrail_name: "lakera-pre-guard"
|
||||||
|
litellm_params:
|
||||||
|
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
|
||||||
|
mode: "during_call"
|
||||||
|
api_key: os.environ/LAKERA_API_KEY
|
||||||
|
api_base: os.environ/LAKERA_API_BASE
|
||||||
|
category_thresholds:
|
||||||
|
prompt_injection: 0.1
|
||||||
|
jailbreak: 0.1
|
||||||
|
|
||||||
|
```
|
|
@ -25,7 +25,12 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||||
from litellm.types.guardrails import GuardrailItem, Role, default_roles
|
from litellm.types.guardrails import (
|
||||||
|
GuardrailItem,
|
||||||
|
LakeraCategoryThresholds,
|
||||||
|
Role,
|
||||||
|
default_roles,
|
||||||
|
)
|
||||||
|
|
||||||
GUARDRAIL_NAME = "lakera_prompt_injection"
|
GUARDRAIL_NAME = "lakera_prompt_injection"
|
||||||
|
|
||||||
|
@ -36,16 +41,11 @@ INPUT_POSITIONING_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class LakeraCategories(TypedDict, total=False):
|
|
||||||
jailbreak: float
|
|
||||||
prompt_injection: float
|
|
||||||
|
|
||||||
|
|
||||||
class lakeraAI_Moderation(CustomGuardrail):
|
class lakeraAI_Moderation(CustomGuardrail):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel",
|
moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel",
|
||||||
category_thresholds: Optional[LakeraCategories] = None,
|
category_thresholds: Optional[LakeraCategoryThresholds] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -72,7 +72,7 @@ class lakeraAI_Moderation(CustomGuardrail):
|
||||||
|
|
||||||
if self.category_thresholds is not None:
|
if self.category_thresholds is not None:
|
||||||
if category_scores is not None:
|
if category_scores is not None:
|
||||||
typed_cat_scores = LakeraCategories(**category_scores)
|
typed_cat_scores = LakeraCategoryThresholds(**category_scores)
|
||||||
if (
|
if (
|
||||||
"jailbreak" in typed_cat_scores
|
"jailbreak" in typed_cat_scores
|
||||||
and "jailbreak" in self.category_thresholds
|
and "jailbreak" in self.category_thresholds
|
||||||
|
@ -219,6 +219,8 @@ class lakeraAI_Moderation(CustomGuardrail):
|
||||||
text = "\n".join(data["input"])
|
text = "\n".join(data["input"])
|
||||||
_json_data = json.dumps({"input": text})
|
_json_data = json.dumps({"input": text})
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data)
|
||||||
|
|
||||||
# https://platform.lakera.ai/account/api-keys
|
# https://platform.lakera.ai/account/api-keys
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -288,7 +290,18 @@ class lakeraAI_Moderation(CustomGuardrail):
|
||||||
"pass_through_endpoint",
|
"pass_through_endpoint",
|
||||||
],
|
],
|
||||||
) -> Optional[Union[Exception, str, Dict]]:
|
) -> Optional[Union[Exception, str, Dict]]:
|
||||||
if self.moderation_check == "in_parallel":
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
if self.event_hook is None:
|
||||||
|
if self.moderation_check == "in_parallel":
|
||||||
|
return None
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.should_run_guardrail(
|
||||||
|
data=data, event_type=GuardrailEventHooks.pre_call
|
||||||
|
)
|
||||||
|
is not True
|
||||||
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return await self._check(
|
return await self._check(
|
||||||
|
|
|
@ -12,6 +12,7 @@ from litellm.types.guardrails import (
|
||||||
Guardrail,
|
Guardrail,
|
||||||
GuardrailItem,
|
GuardrailItem,
|
||||||
GuardrailItemSpec,
|
GuardrailItemSpec,
|
||||||
|
LakeraCategoryThresholds,
|
||||||
LitellmParams,
|
LitellmParams,
|
||||||
guardrailConfig,
|
guardrailConfig,
|
||||||
)
|
)
|
||||||
|
@ -99,6 +100,15 @@ def init_guardrails_v2(all_guardrails: dict):
|
||||||
api_base=litellm_params_data["api_base"],
|
api_base=litellm_params_data["api_base"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
"category_thresholds" in litellm_params_data
|
||||||
|
and litellm_params_data["category_thresholds"]
|
||||||
|
):
|
||||||
|
lakera_category_thresholds = LakeraCategoryThresholds(
|
||||||
|
**litellm_params_data["category_thresholds"]
|
||||||
|
)
|
||||||
|
litellm_params["category_thresholds"] = lakera_category_thresholds
|
||||||
|
|
||||||
if litellm_params["api_key"]:
|
if litellm_params["api_key"]:
|
||||||
if litellm_params["api_key"].startswith("os.environ/"):
|
if litellm_params["api_key"].startswith("os.environ/"):
|
||||||
litellm_params["api_key"] = litellm.get_secret(
|
litellm_params["api_key"] = litellm.get_secret(
|
||||||
|
@ -134,6 +144,7 @@ def init_guardrails_v2(all_guardrails: dict):
|
||||||
api_key=litellm_params["api_key"],
|
api_key=litellm_params["api_key"],
|
||||||
guardrail_name=guardrail["guardrail_name"],
|
guardrail_name=guardrail["guardrail_name"],
|
||||||
event_hook=litellm_params["mode"],
|
event_hook=litellm_params["mode"],
|
||||||
|
category_thresholds=litellm_params.get("category_thresholds"),
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_lakera_callback) # type: ignore
|
litellm.callbacks.append(_lakera_callback) # type: ignore
|
||||||
|
|
||||||
|
|
|
@ -12,4 +12,7 @@ guardrails:
|
||||||
mode: "during_call"
|
mode: "during_call"
|
||||||
api_key: os.environ/LAKERA_API_KEY
|
api_key: os.environ/LAKERA_API_KEY
|
||||||
api_base: os.environ/LAKERA_API_BASE
|
api_base: os.environ/LAKERA_API_BASE
|
||||||
|
category_thresholds:
|
||||||
|
prompt_injection: 0.1
|
||||||
|
jailbreak: 0.1
|
||||||
|
|
|
@ -66,11 +66,17 @@ class GuardrailItem(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
# Define the TypedDicts
|
# Define the TypedDicts
|
||||||
class LitellmParams(TypedDict):
|
class LakeraCategoryThresholds(TypedDict, total=False):
|
||||||
|
prompt_injection: float
|
||||||
|
jailbreak: float
|
||||||
|
|
||||||
|
|
||||||
|
class LitellmParams(TypedDict, total=False):
|
||||||
guardrail: str
|
guardrail: str
|
||||||
mode: str
|
mode: str
|
||||||
api_key: str
|
api_key: str
|
||||||
api_base: Optional[str]
|
api_base: Optional[str]
|
||||||
|
category_thresholds: Optional[LakeraCategoryThresholds]
|
||||||
|
|
||||||
|
|
||||||
class Guardrail(TypedDict):
|
class Guardrail(TypedDict):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue