mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(aporio_ai.py): support aporio ai prompt injection for chat completion requests
Closes https://github.com/BerriAI/litellm/issues/2950
This commit is contained in:
parent
e587d32058
commit
07d90f6739
5 changed files with 217 additions and 6 deletions
|
@ -31,6 +31,7 @@ Features:
|
||||||
- **Guardrails, PII Masking, Content Moderation**
|
- **Guardrails, PII Masking, Content Moderation**
|
||||||
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
|
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
|
||||||
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
|
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
|
||||||
|
- ✅ [Prompt Injection Detection (with Aporio API)](#prompt-injection-detection---aporio-ai)
|
||||||
- ✅ [Switch LakeraAI on / off per request](guardrails#control-guardrails-onoff-per-request)
|
- ✅ [Switch LakeraAI on / off per request](guardrails#control-guardrails-onoff-per-request)
|
||||||
- ✅ Reject calls from Blocked User list
|
- ✅ Reject calls from Blocked User list
|
||||||
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
|
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
|
||||||
|
@ -953,6 +954,72 @@ curl --location 'http://localhost:4000/chat/completions' \
|
||||||
Need to control LakeraAI per request? Doc here 👉: [Switch LakeraAI on / off per request](prompt_injection.md#✨-enterprise-switch-lakeraai-on--off-per-api-call)
|
Need to control LakeraAI per request? Doc here 👉: [Switch LakeraAI on / off per request](prompt_injection.md#✨-enterprise-switch-lakeraai-on--off-per-api-call)
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
## Prompt Injection Detection - Aporio AI
|
||||||
|
|
||||||
|
Use this if you want to reject `/chat/completion` calls that contain prompt injection attacks, using [AporioAI](https://www.aporia.com/)
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
|
||||||
|
Step 1. Add env
|
||||||
|
|
||||||
|
```env
|
||||||
|
APORIO_API_KEY="eyJh****"
|
||||||
|
APORIO_API_BASE="https://gr..."
|
||||||
|
```
|
||||||
|
|
||||||
|
Step 2. Add `aporio_prompt_injection` to your callbacks
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
litellm_settings:
|
||||||
|
callbacks: ["aporio_prompt_injection"]
|
||||||
|
```
|
||||||
|
|
||||||
|
That's it, start your proxy
|
||||||
|
|
||||||
|
Test it with this request -> expect it to get rejected by LiteLLM Proxy
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://localhost:4000/chat/completions' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "llama3",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "You suck!"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Expected Response**
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"error": {
|
||||||
|
"message": {
|
||||||
|
"error": "Violated guardrail policy",
|
||||||
|
"aporio_ai_response": {
|
||||||
|
"action": "block",
|
||||||
|
"revised_prompt": null,
|
||||||
|
"revised_response": "Profanity detected: Message blocked because it includes profanity. Please rephrase.",
|
||||||
|
"explain_log": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "None",
|
||||||
|
"param": "None",
|
||||||
|
"code": 400
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Need to control AporioAI per Request ? Doc here 👉: [Create a guardrail](./guardrails.md)
|
||||||
|
:::
|
||||||
|
|
||||||
|
|
||||||
## Swagger Docs - Custom Routes + Branding
|
## Swagger Docs - Custom Routes + Branding
|
||||||
|
|
||||||
:::info
|
:::info
|
||||||
|
|
124
enterprise/enterprise_hooks/aporio_ai.py
Normal file
124
enterprise/enterprise_hooks/aporio_ai.py
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
# +-------------------------------------------------------------+
|
||||||
|
#
|
||||||
|
# Use AporioAI for your LLM calls
|
||||||
|
#
|
||||||
|
# +-------------------------------------------------------------+
|
||||||
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
|
||||||
|
import sys, os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
from typing import Optional, Literal, Union
|
||||||
|
import litellm, traceback, sys, uuid
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||||
|
from typing import List
|
||||||
|
from datetime import datetime
|
||||||
|
import aiohttp, asyncio
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
|
||||||
|
# NOTE(review): forces verbose logging globally as a module import side
# effect — looks like leftover debug code; confirm it is intended to ship.
litellm.set_verbose = True

# Guardrail identifier; matched against per-request guardrail metadata to
# decide whether this hook runs (see should_proceed_based_on_metadata).
GUARDRAIL_NAME = "aporio"
|
||||||
|
|
||||||
|
|
||||||
|
class _ENTERPRISE_Aporio(CustomLogger):
    """Aporia AI guardrail hook.

    Sends the request's chat messages to Aporia's ``/validate`` endpoint
    before the LLM call and rejects the request (HTTP 400) when Aporia
    returns ``action == "block"`` (e.g. prompt injection, profanity).
    """

    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
        """Create the hook.

        Args:
            api_key: Aporia API key; falls back to env ``APORIO_API_KEY``.
            api_base: Aporia base URL; falls back to env ``APORIO_API_BASE``.

        Raises:
            KeyError: if a value is neither passed nor present in the
                environment — surfaces misconfiguration at startup rather
                than on the first request.
        """
        # Long read timeout: the guardrail call blocks the request path and
        # large prompts may take a while to validate; connect stays short.
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
        self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"]
        self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"]

    #### CALL HOOKS - proxy only ####

    def transform_messages(self, messages: List[dict]) -> List[dict]:
        """Map OpenAI-style messages to roles Aporia accepts.

        Messages whose role is system/user/assistant pass through unchanged;
        any other role (e.g. ``tool``) is re-labelled ``"other"`` while all
        remaining keys are preserved.
        """
        supported_openai_roles = ["system", "user", "assistant"]
        default_role = "other"  # for unsupported roles - e.g. tool
        new_messages = []
        for m in messages:
            if m.get("role", "") in supported_openai_roles:
                new_messages.append(m)
            else:
                new_messages.append(
                    {
                        "role": default_role,
                        **{key: value for key, value in m.items() if key != "role"},
                    }
                )

        return new_messages

    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        """Validate the prompt with Aporia before forwarding to the LLM.

        Returns early (no-op) when per-request guardrail metadata disables
        this guardrail. Otherwise POSTs the (role-normalized) messages to
        ``{api_base}/validate``.

        Raises:
            HTTPException: status 400 with the full Aporia response in
                ``detail`` when Aporia's action is ``"block"``.
        """
        # Respect per-request guardrail on/off metadata.
        if (
            await should_proceed_based_on_metadata(
                data=data,
                guardrail_name=GUARDRAIL_NAME,
            )
            is False
        ):
            return

        new_messages: Optional[List[dict]] = None
        if "messages" in data and isinstance(data["messages"], list):
            new_messages = self.transform_messages(messages=data["messages"])

        if new_messages is not None:
            # Rebind to the Aporia payload shape; the caller's dict is not
            # mutated. NOTE(review): when the request has no "messages" list
            # the raw request body is sent as-is — confirm that is intended.
            data = {"messages": new_messages, "validation_target": "prompt"}

        _json_data = json.dumps(data)

        # Equivalent manual request, for debugging:
        #   curl https://gr-prd-trial.aporia.com/<some-id>/validate -X POST \
        #     -H "X-APORIA-API-KEY: $APORIO_API_KEY" \
        #     -H "Content-Type: application/json" \
        #     -d '{"messages": [{"role": "user", "content": "..."}]}'
        response = await self.async_handler.post(
            url=self.aporio_api_base + "/validate",
            data=_json_data,
            headers={
                "X-APORIA-API-KEY": self.aporio_api_key,
                "Content-Type": "application/json",
            },
        )
        verbose_proxy_logger.debug("Aporio AI response: %s", response.text)
        if response.status_code == 200:
            # check if the response was flagged
            _json_response = response.json()
            # Possible values: modify, passthrough, block, rephrase. .get()
            # returns None when the key is absent (the previous `str`
            # annotation was wrong), in which case we fail open — only an
            # explicit "block" rejects the request.
            action: Optional[str] = _json_response.get("action")
            if action == "block":
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated guardrail policy",
                        "aporio_ai_response": _json_response,
                    },
                )
|
|
@ -1,5 +1,10 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: groq-whisper
|
- model_name: "*"
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: groq/whisper-large-v3
|
model: openai/*
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
guardrails:
|
||||||
|
- prompt_injection:
|
||||||
|
callbacks: ["aporio_prompt_injection"]
|
||||||
|
default_on: true
|
||||||
|
|
|
@ -112,6 +112,17 @@ def initialize_callbacks_on_proxy(
|
||||||
|
|
||||||
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation()
|
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation()
|
||||||
imported_list.append(lakera_moderations_object)
|
imported_list.append(lakera_moderations_object)
|
||||||
|
elif isinstance(callback, str) and callback == "aporio_prompt_injection":
|
||||||
|
from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio
|
||||||
|
|
||||||
|
if premium_user is not True:
|
||||||
|
raise Exception(
|
||||||
|
"Trying to use Aporio AI Guardrail"
|
||||||
|
+ CommonProxyErrors.not_premium_user.value
|
||||||
|
)
|
||||||
|
|
||||||
|
aporio_guardrail_object = _ENTERPRISE_Aporio()
|
||||||
|
imported_list.append(aporio_guardrail_object)
|
||||||
elif isinstance(callback, str) and callback == "google_text_moderation":
|
elif isinstance(callback, str) and callback == "google_text_moderation":
|
||||||
from enterprise.enterprise_hooks.google_text_moderation import (
|
from enterprise.enterprise_hooks.google_text_moderation import (
|
||||||
_ENTERPRISE_GoogleTextModeration,
|
_ENTERPRISE_GoogleTextModeration,
|
||||||
|
|
|
@ -453,8 +453,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
|
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
|
||||||
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
|
global_max_parallel_requests = (
|
||||||
"global_max_parallel_requests", None
|
kwargs["litellm_params"]
|
||||||
|
.get("metadata", {})
|
||||||
|
.get("global_max_parallel_requests", None)
|
||||||
)
|
)
|
||||||
user_api_key = (
|
user_api_key = (
|
||||||
kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
|
kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
|
||||||
|
@ -516,5 +518,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
) # save in cache for up to 1 min.
|
) # save in cache for up to 1 min.
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"Inside Parallel Request Limiter: An exception occurred - {str(e)}."
|
"Inside Parallel Request Limiter: An exception occurred - {}\n{}".format(
|
||||||
|
str(e), traceback.format_exc()
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue