feat(aporio_ai.py): support aporio ai prompt injection for chat completion requests

Closes https://github.com/BerriAI/litellm/issues/2950
Krrish Dholakia 2024-07-17 16:38:47 -07:00
parent e587d32058
commit 07d90f6739
5 changed files with 217 additions and 6 deletions

@@ -31,6 +31,7 @@ Features:
- **Guardrails, PII Masking, Content Moderation**
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
- ✅ [Prompt Injection Detection (with Aporio API)](#prompt-injection-detection---aporio-ai)
- ✅ [Switch LakeraAI on / off per request](guardrails#control-guardrails-onoff-per-request)
- ✅ Reject calls from Blocked User list
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
@@ -953,6 +954,72 @@ curl --location 'http://localhost:4000/chat/completions' \
Need to control LakeraAI per request? Doc here 👉: [Switch LakeraAI on / off per request](prompt_injection.md#✨-enterprise-switch-lakeraai-on--off-per-api-call)
:::
## Prompt Injection Detection - Aporio AI

Use this if you want to reject `/chat/completions` calls that contain prompt injection attacks, with [AporioAI](https://www.aporia.com/).

#### Usage

Step 1. Add env vars
```env
APORIO_API_KEY="eyJh****"
APORIO_API_BASE="https://gr..."
```
Step 2. Add `aporio_prompt_injection` to your callbacks
```yaml
litellm_settings:
  callbacks: ["aporio_prompt_injection"]
```
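
Putting both steps together, a minimal `config.yaml` could look like the sketch below (the model entry is illustrative only; use your own):

```yaml
model_list:
  - model_name: llama3              # hypothetical entry, not part of this change
    litellm_params:
      model: groq/llama3-8b-8192
litellm_settings:
  callbacks: ["aporio_prompt_injection"]
```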
That's it. Start your proxy.

Test it with this request; expect the LiteLLM Proxy to reject it:
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "llama3",
    "messages": [
        {
            "role": "user",
            "content": "You suck!"
        }
    ]
}'
```
**Expected Response**
```json
{
    "error": {
        "message": {
            "error": "Violated guardrail policy",
            "aporio_ai_response": {
                "action": "block",
                "revised_prompt": null,
                "revised_response": "Profanity detected: Message blocked because it includes profanity. Please rephrase.",
                "explain_log": null
            }
        },
        "type": "None",
        "param": "None",
        "code": 400
    }
}
```
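
You can trigger the same rejection from code. A minimal sketch using the OpenAI Python SDK against the proxy configured above (the base URL and key are the placeholder values from this doc):

```python
import openai

# Point the SDK at the LiteLLM Proxy from the steps above
client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

try:
    client.chat.completions.create(
        model="llama3",
        messages=[{"role": "user", "content": "You suck!"}],
    )
except openai.BadRequestError as e:
    # the guardrail rejection surfaces as an HTTP 400
    print("Blocked by guardrail:", e)
```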
:::info
Need to control AporioAI per request? Doc here 👉: [Create a guardrail](./guardrails.md)
:::
## Swagger Docs - Custom Routes + Branding
:::info

@@ -0,0 +1,124 @@
# +-------------------------------------------------------------+
#
# Use AporioAI for your LLM calls
#
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import json
from typing import List, Literal, Optional

import httpx
from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata

litellm.set_verbose = True

GUARDRAIL_NAME = "aporio"
class _ENTERPRISE_Aporio(CustomLogger):
    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
        # fall back to env vars when credentials aren't passed explicitly
        self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"]
        self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"]

    #### CALL HOOKS - proxy only ####

    def transform_messages(self, messages: List[dict]) -> List[dict]:
        supported_openai_roles = ["system", "user", "assistant"]
        default_role = "other"  # for unsupported roles - e.g. tool
        new_messages = []
        for m in messages:
            if m.get("role", "") in supported_openai_roles:
                new_messages.append(m)
            else:
                # keep every field, just swap the unsupported role
                new_messages.append(
                    {
                        "role": default_role,
                        **{key: value for key, value in m.items() if key != "role"},
                    }
                )

        return new_messages
    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        if (
            await should_proceed_based_on_metadata(
                data=data,
                guardrail_name=GUARDRAIL_NAME,
            )
            is False
        ):
            return

        new_messages: Optional[List[dict]] = None
        if "messages" in data and isinstance(data["messages"], list):
            new_messages = self.transform_messages(messages=data["messages"])

        if new_messages is not None:
            data = {"messages": new_messages, "validation_target": "prompt"}

            _json_data = json.dumps(data)

            """
            export APORIO_API_KEY=<your key>
            curl https://gr-prd-trial.aporia.com/some-id \
                -X POST \
                -H "X-APORIA-API-KEY: $APORIO_API_KEY" \
                -H "Content-Type: application/json" \
                -d '{
                    "messages": [
                        {
                            "role": "user",
                            "content": "This is a test prompt"
                        }
                    ]
                }'
            """

            response = await self.async_handler.post(
                url=self.aporio_api_base + "/validate",
                data=_json_data,
                headers={
                    "X-APORIA-API-KEY": self.aporio_api_key,
                    "Content-Type": "application/json",
                },
            )
            verbose_proxy_logger.debug("Aporio AI response: %s", response.text)
            if response.status_code == 200:
                # check if the response was flagged
                _json_response = response.json()
                action: str = _json_response.get(
                    "action"
                )  # possible values are modify, passthrough, block, rephrase
                if action == "block":
                    raise HTTPException(
                        status_code=400,
                        detail={
                            "error": "Violated guardrail policy",
                            "aporio_ai_response": _json_response,
                        },
                    )
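
For illustration, a minimal standalone sketch of the role normalization that `transform_messages` performs above (hypothetical copy, not part of this file): only `system`/`user`/`assistant` pass through unchanged, and any other role, e.g. `tool`, is re-labeled `other` with the rest of the message preserved.

```python
# Standalone copy of the transform_messages logic (illustrative only)
supported_openai_roles = ["system", "user", "assistant"]

def transform_messages(messages: list) -> list:
    return [
        m
        if m.get("role", "") in supported_openai_roles
        else {"role": "other", **{k: v for k, v in m.items() if k != "role"}}
        for m in messages
    ]

print(transform_messages([
    {"role": "user", "content": "hi"},
    {"role": "tool", "content": "72F", "tool_call_id": "call_1"},
]))
# [{'role': 'user', 'content': 'hi'},
#  {'role': 'other', 'content': '72F', 'tool_call_id': 'call_1'}]
```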

@@ -1,5 +1,10 @@
 model_list:
-  - model_name: groq-whisper
+  - model_name: "*"
     litellm_params:
-      model: groq/whisper-large-v3
+      model: openai/*
 litellm_settings:
+  guardrails:
+    - prompt_injection:
+        callbacks: ["aporio_prompt_injection"]
+        default_on: true

@@ -112,6 +112,17 @@ def initialize_callbacks_on_proxy(
            lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation()
            imported_list.append(lakera_moderations_object)
        elif isinstance(callback, str) and callback == "aporio_prompt_injection":
            from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio

            if premium_user is not True:
                raise Exception(
                    "Trying to use Aporio AI Guardrail"
                    + CommonProxyErrors.not_premium_user.value
                )

            aporio_guardrail_object = _ENTERPRISE_Aporio()
            imported_list.append(aporio_guardrail_object)
        elif isinstance(callback, str) and callback == "google_text_moderation":
            from enterprise.enterprise_hooks.google_text_moderation import (
                _ENTERPRISE_GoogleTextModeration,

@@ -453,8 +453,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         try:
             self.print_verbose(f"Inside Max Parallel Request Failure Hook")
-            global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
-                "global_max_parallel_requests", None
+            global_max_parallel_requests = (
+                kwargs["litellm_params"]
+                .get("metadata", {})
+                .get("global_max_parallel_requests", None)
             )
             user_api_key = (
                 kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
@@ -516,5 +518,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             )  # save in cache for up to 1 min.
         except Exception as e:
             verbose_proxy_logger.info(
-                f"Inside Parallel Request Limiter: An exception occurred - {str(e)}."
+                "Inside Parallel Request Limiter: An exception occurred - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
             )
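
The point of the chained `.get` calls in both hunks: a request whose `litellm_params` carries no `metadata` key previously raised `KeyError` inside the failure hook. A minimal sketch of the difference:

```python
kwargs = {"litellm_params": {}}  # no "metadata" key on this request

# old form: kwargs["litellm_params"]["metadata"] raises KeyError here
value = (
    kwargs["litellm_params"]
    .get("metadata", {})
    .get("global_max_parallel_requests", None)
)
assert value is None  # missing keys now degrade to the default instead of raising
```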