Merge pull request #5288 from BerriAI/litellm_aporia_refactor

[Feat] V2 aporia guardrails litellm
This commit is contained in:
Ishaan Jaff 2024-08-19 20:41:45 -07:00 committed by GitHub
commit c82714757a
33 changed files with 1078 additions and 337 deletions

View file

@ -317,6 +317,10 @@ jobs:
-e OPENAI_API_KEY=$OPENAI_API_KEY \
-e LITELLM_LICENSE=$LITELLM_LICENSE \
-e OTEL_EXPORTER="in_memory" \
-e APORIA_API_BASE_2=$APORIA_API_BASE_2 \
-e APORIA_API_KEY_2=$APORIA_API_KEY_2 \
-e APORIA_API_BASE_1=$APORIA_API_BASE_1 \
-e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
--name my-app \
-v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
my-app:latest \

View file

@ -47,6 +47,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
async def async_post_call_success_hook(
    self,
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
    response,
):

View file

@ -36,7 +36,7 @@ Features:
- **Guardrails, PII Masking, Content Moderation**
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
- - ✅ [Prompt Injection Detection (with Aporio API)](#prompt-injection-detection---aporio-ai)
+ - ✅ [Prompt Injection Detection (with Aporia API)](#prompt-injection-detection---aporia-ai)
- ✅ [Switch LakeraAI on / off per request](guardrails#control-guardrails-onoff-per-request)
- ✅ Reject calls from Blocked User list
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
@ -1035,9 +1035,9 @@ curl --location 'http://localhost:4000/chat/completions' \
Need to control LakeraAI per Request ? Doc here 👉: [Switch LakerAI on / off per request](prompt_injection.md#✨-enterprise-switch-lakeraai-on--off-per-api-call)
:::

- ## Prompt Injection Detection - Aporio AI
+ ## Prompt Injection Detection - Aporia AI

- Use this if you want to reject /chat/completion calls that have prompt injection attacks with [AporioAI](https://www.aporia.com/)
+ Use this if you want to reject /chat/completion calls that have prompt injection attacks with [AporiaAI](https://www.aporia.com/)

#### Usage
@ -1048,11 +1048,11 @@ APORIO_API_KEY="eyJh****"
APORIO_API_BASE="https://gr..."
```
- Step 2. Add `aporio_prompt_injection` to your callbacks
+ Step 2. Add `aporia_prompt_injection` to your callbacks

```yaml
litellm_settings:
-   callbacks: ["aporio_prompt_injection"]
+   callbacks: ["aporia_prompt_injection"]
```

That's it, start your proxy
@ -1081,7 +1081,7 @@ curl --location 'http://localhost:4000/chat/completions' \
  "error": {
    "message": {
      "error": "Violated guardrail policy",
-     "aporio_ai_response": {
+     "aporia_ai_response": {
        "action": "block",
        "revised_prompt": null,
        "revised_response": "Profanity detected: Message blocked because it includes profanity. Please rephrase.",
@ -1097,7 +1097,7 @@ curl --location 'http://localhost:4000/chat/completions' \
:::info

- Need to control AporioAI per Request ? Doc here 👉: [Create a guardrail](./guardrails.md)
+ Need to control AporiaAI per Request ? Doc here 👉: [Create a guardrail](./guardrails.md)

:::

View file

@ -1,18 +1,10 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

- # 🛡️ Guardrails
+ # 🛡️ [Beta] Guardrails

Setup Prompt Injection Detection, Secret Detection on LiteLLM Proxy

- :::info
- ✨ Enterprise Only Feature
- Schedule a meeting with us to get an Enterprise License 👉 Talk to founders [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
- :::

## Quick Start
### 1. Setup guardrails on litellm proxy config.yaml

View file

@ -1,98 +0,0 @@
import Image from '@theme/IdealImage';
# Split traffic betwen GPT-4 and Llama2 in Production!
In this tutorial, we'll walk through A/B testing between GPT-4 and Llama2 in production. We'll assume you've deployed Llama2 on Huggingface Inference Endpoints (but any of TogetherAI, Baseten, Ollama, Petals, Openrouter should work as well).
# Relevant Resources:
* 🚀 [Your production dashboard!](https://admin.litellm.ai/)
* [Deploying models on Huggingface](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint)
* [All supported providers on LiteLLM](https://docs.litellm.ai/docs/providers)
# Code Walkthrough
In production, we don't know if Llama2 is going to provide:
* good results
* quickly
### 💡 Route 20% traffic to Llama2
If Llama2 returns poor answers / is extremely slow, we want to roll-back this change, and use GPT-4 instead.
Instead of routing 100% of our traffic to Llama2, let's **start by routing 20% traffic** to it and see how it does.
```python
## route 20% of responses to Llama2
split_per_model = {
"gpt-4": 0.8,
"huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2
}
```
## 👨‍💻 Complete Code
### a) For Local
If we're testing this in a script - this is what our complete code looks like.
```python
from litellm import completion_with_split_tests
import os
## set ENV variables
os.environ["OPENAI_API_KEY"] = "openai key"
os.environ["HUGGINGFACE_API_KEY"] = "huggingface key"
## route 20% of responses to Llama2
split_per_model = {
"gpt-4": 0.8,
"huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2
}
messages = [{ "content": "Hello, how are you?","role": "user"}]
completion_with_split_tests(
models=split_per_model,
messages=messages,
)
```
### b) For Production
If we're in production, we don't want to keep going to code to change model/test details (prompt, split%, etc.) for our completion function and redeploying changes.
LiteLLM exposes a client dashboard to do this in a UI - and instantly updates our completion function in prod.
#### Relevant Code
```python
completion_with_split_tests(..., use_client=True, id="my-unique-id")
```
#### Complete Code
```python
from litellm import completion_with_split_tests
import os
## set ENV variables
os.environ["OPENAI_API_KEY"] = "openai key"
os.environ["HUGGINGFACE_API_KEY"] = "huggingface key"
## route 20% of responses to Llama2
split_per_model = {
"gpt-4": 0.8,
"huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2
}
messages = [{ "content": "Hello, how are you?","role": "user"}]
completion_with_split_tests(
models=split_per_model,
messages=messages,
use_client=True,
id="my-unique-id" # Auto-create this @ https://admin.litellm.ai/
)
```

View file

@ -0,0 +1,196 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Use LiteLLM AI Gateway with Aporia Guardrails
In this tutorial we will use LiteLLM Proxy with Aporia to detect PII in requests and profanity in responses
## 1. Setup guardrails on Aporia
### Create Aporia Projects
Create two projects on [Aporia](https://guardrails.aporia.com/)
1. Pre LLM API Call - Set all the policies you want to run on pre LLM API call
2. Post LLM API Call - Set all the policies you want to run post LLM API call
<Image img={require('../../img/aporia_projs.png')} />
### Pre-Call: Detect PII
Add the `PII - Prompt` policy to your Pre LLM API Call project
<Image img={require('../../img/aporia_pre.png')} />
### Post-Call: Detect Profanity in Responses
Add the `Toxicity - Response` policy to your Post LLM API Call project
<Image img={require('../../img/aporia_post.png')} />
## 2. Define Guardrails on your LiteLLM config.yaml
- Define your guardrails under the `guardrails` section of your LiteLLM config.yaml and set the `mode` for each guardrail (supported values are listed below)
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: openai/gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
guardrails:
- guardrail_name: "aporia-pre-guard"
litellm_params:
guardrail: aporia # supported values: "aporia", "lakera"
mode: "during_call"
api_key: os.environ/APORIA_API_KEY_1
api_base: os.environ/APORIA_API_BASE_1
- guardrail_name: "aporia-post-guard"
litellm_params:
guardrail: aporia # supported values: "aporia", "lakera"
mode: "post_call"
api_key: os.environ/APORIA_API_KEY_2
api_base: os.environ/APORIA_API_BASE_2
```
### Supported values for `mode`
- `pre_call` Run **before** LLM call, on **input**
- `post_call` Run **after** LLM call, on **input & output**
- `during_call` Run **during** LLM call, on **input**
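For orientation, `mode` is stored on each guardrail object as its event hook, and every hook first checks that the request asked for the guardrail by name and that the configured mode matches the current event. A rough sketch of that check, mirroring `CustomGuardrail.should_run_guardrail` from this PR (not a new API):

```python
from litellm.types.guardrails import GuardrailEventHooks


def should_run_guardrail(
    guardrail_name: str,
    event_hook: GuardrailEventHooks,   # the configured `mode`
    data: dict,                        # the incoming request body
    event_type: GuardrailEventHooks,   # the event currently being processed
) -> bool:
    # the request must have asked for this guardrail by name ...
    requested = (data.get("metadata") or {}).get("guardrails") or []
    if guardrail_name not in requested:
        return False
    # ... and the configured mode must match the current event
    return event_hook == event_type
```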
## 3. Start LiteLLM Gateway
```shell
litellm --config config.yaml --detailed_debug
```
## 4. Test request
<Tabs>
<TabItem label="Unsuccessful call" value = "not-allowed">
Expect this call to fail, since `ishaan@berri.ai` in the request is PII
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "hi my email is ishaan@berri.ai"}
],
"guardrails": ["aporia-pre-guard", "aporia-post-guard"]
}'
```
Expected response on failure
```shell
{
"error": {
"message": {
"error": "Violated guardrail policy",
"aporia_ai_response": {
"action": "block",
"revised_prompt": null,
"revised_response": "Aporia detected and blocked PII",
"explain_log": null
}
},
"type": "None",
"param": "None",
"code": "400"
}
}
```
</TabItem>
<TabItem label="Successful Call " value = "allowed">
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "hi what is the weather"}
],
"guardrails": ["aporia-pre-guard", "aporia-post-guard"]
}'
```
</TabItem>
</Tabs>
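The same request can also be sent through the OpenAI Python SDK. This is a sketch only, assuming the proxy is reachable at `localhost:4000` and the key shown above is valid; fields passed via `extra_body` (here, `guardrails`) are forwarded to the proxy, which moves them into request metadata:

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:4000/v1",
    api_key="sk-npnwjPQciVRok5yNZgKmFQ",
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi what is the weather"}],
    # forwarded as a top-level "guardrails" field on the request body
    extra_body={"guardrails": ["aporia-pre-guard", "aporia-post-guard"]},
)
print(response.choices[0].message.content)
```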
## Advanced
### Control Guardrails per Project (API Key)
Use this to control which guardrails run per project (API key). In this tutorial we only want the following guardrails to run for one project:
- `pre_call_guardrails`: ["aporia-pre-guard"]
- `post_call_guardrails`: ["aporia-post-guard"]
**Step 1** Create Key with guardrail settings
<Tabs>
<TabItem value="/key/generate" label="/key/generate">
```shell
curl -X POST 'http://0.0.0.0:4000/key/generate' \
    -H 'Authorization: Bearer sk-1234' \
    -H 'Content-Type: application/json' \
    -d '{
        "pre_call_guardrails": ["aporia-pre-guard"],
        "post_call_guardrails": ["aporia-post-guard"]
    }'
```
</TabItem>
<TabItem value="/key/update" label="/key/update">
```shell
curl --location 'http://0.0.0.0:4000/key/update' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "key": "sk-jNm1Zar7XfNdZXp49Z1kSQ",
    "pre_call_guardrails": ["aporia-pre-guard"],
    "post_call_guardrails": ["aporia-post-guard"]
}'
```
</TabItem>
</Tabs>
**Step 2** Test it with new key
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-jNm1Zar7XfNdZXp49Z1kSQ' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "my email is ishaan@berri.ai"
}
]
}'
```

Binary file not shown (new image, 250 KiB).

Binary file not shown (new image, 277 KiB).

Binary file not shown (new image, 153 KiB).

View file

@ -250,6 +250,7 @@ const sidebars = {
type: "category",
label: "Tutorials",
items: [
    'tutorials/litellm_proxy_aporia',
    'tutorials/azure_openai',
    'tutorials/instructor',
    "tutorials/gradio_integration",

View file

@ -0,0 +1,208 @@
# +-------------------------------------------------------------+
#
# Use AporiaAI for your LLM calls
#
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from typing import Optional, Literal, Union, Any
import litellm, traceback, sys, uuid
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_guardrail import CustomGuardrail
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.litellm_core_utils.logging_utils import (
convert_litellm_response_object_to_str,
)
from typing import List
from datetime import datetime
import aiohttp, asyncio
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
import httpx
import json
from litellm.types.guardrails import GuardrailEventHooks
litellm.set_verbose = True
GUARDRAIL_NAME = "aporia"
class _ENTERPRISE_Aporia(CustomGuardrail):
def __init__(
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
self.event_hook: GuardrailEventHooks
super().__init__(**kwargs)
#### CALL HOOKS - proxy only ####
def transform_messages(self, messages: List[dict]) -> List[dict]:
supported_openai_roles = ["system", "user", "assistant"]
default_role = "other" # for unsupported roles - e.g. tool
new_messages = []
for m in messages:
if m.get("role", "") in supported_openai_roles:
new_messages.append(m)
else:
new_messages.append(
{
"role": default_role,
**{key: value for key, value in m.items() if key != "role"},
}
)
return new_messages
async def prepare_aporia_request(
self, new_messages: List[dict], response_string: Optional[str] = None
) -> dict:
data: dict[str, Any] = {}
if new_messages is not None:
data["messages"] = new_messages
if response_string is not None:
data["response"] = response_string
# Set validation target
if new_messages and response_string:
data["validation_target"] = "both"
elif new_messages:
data["validation_target"] = "prompt"
elif response_string:
data["validation_target"] = "response"
verbose_proxy_logger.debug("Aporia AI request: %s", data)
return data
async def make_aporia_api_request(
self, new_messages: List[dict], response_string: Optional[str] = None
):
data = await self.prepare_aporia_request(
new_messages=new_messages, response_string=response_string
)
_json_data = json.dumps(data)
"""
export APORIO_API_KEY=<your key>
curl https://gr-prd-trial.aporia.com/some-id \
-X POST \
-H "X-APORIA-API-KEY: $APORIO_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{
"role": "user",
"content": "This is a test prompt"
}
],
}
'
"""
response = await self.async_handler.post(
url=self.aporia_api_base + "/validate",
data=_json_data,
headers={
"X-APORIA-API-KEY": self.aporia_api_key,
"Content-Type": "application/json",
},
)
verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
if response.status_code == 200:
# check if the response was flagged
_json_response = response.json()
action: str = _json_response.get(
"action"
) # possible values are modify, passthrough, block, rephrase
if action == "block":
raise HTTPException(
status_code=400,
detail={
"error": "Violated guardrail policy",
"aporia_ai_response": _json_response,
},
)
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
"""
Use this for the post call moderation with Guardrails
"""
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
return
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
if response_str is not None:
await self.make_aporia_api_request(
response_string=response_str, new_messages=data.get("messages", [])
)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
pass
async def async_moderation_hook( ### 👈 KEY CHANGE ###
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
return
# old implementation - backwards compatibility
if (
await should_proceed_based_on_metadata(
data=data,
guardrail_name=GUARDRAIL_NAME,
)
is False
):
return
new_messages: Optional[List[dict]] = None
if "messages" in data and isinstance(data["messages"], list):
new_messages = self.transform_messages(messages=data["messages"])
if new_messages is not None:
await self.make_aporia_api_request(new_messages=new_messages)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
else:
verbose_proxy_logger.warning(
"Aporia AI: not running guardrail. No messages in data"
)
pass

View file

@ -1,124 +0,0 @@
# +-------------------------------------------------------------+
#
# Use AporioAI for your LLM calls
#
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from typing import List
from datetime import datetime
import aiohttp, asyncio
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
import httpx
import json
litellm.set_verbose = True
GUARDRAIL_NAME = "aporio"
class _ENTERPRISE_Aporio(CustomLogger):
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"]
self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"]
#### CALL HOOKS - proxy only ####
def transform_messages(self, messages: List[dict]) -> List[dict]:
supported_openai_roles = ["system", "user", "assistant"]
default_role = "other" # for unsupported roles - e.g. tool
new_messages = []
for m in messages:
if m.get("role", "") in supported_openai_roles:
new_messages.append(m)
else:
new_messages.append(
{
"role": default_role,
**{key: value for key, value in m.items() if key != "role"},
}
)
return new_messages
async def async_moderation_hook( ### 👈 KEY CHANGE ###
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
if (
await should_proceed_based_on_metadata(
data=data,
guardrail_name=GUARDRAIL_NAME,
)
is False
):
return
new_messages: Optional[List[dict]] = None
if "messages" in data and isinstance(data["messages"], list):
new_messages = self.transform_messages(messages=data["messages"])
if new_messages is not None:
data = {"messages": new_messages, "validation_target": "prompt"}
_json_data = json.dumps(data)
"""
export APORIO_API_KEY=<your key>
curl https://gr-prd-trial.aporia.com/some-id \
-X POST \
-H "X-APORIA-API-KEY: $APORIO_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{
"role": "user",
"content": "This is a test prompt"
}
],
}
'
"""
response = await self.async_handler.post(
url=self.aporio_api_base + "/validate",
data=_json_data,
headers={
"X-APORIA-API-KEY": self.aporio_api_key,
"Content-Type": "application/json",
},
)
verbose_proxy_logger.debug("Aporio AI response: %s", response.text)
if response.status_code == 200:
# check if the response was flagged
_json_response = response.json()
action: str = _json_response.get(
"action"
) # possible values are modify, passthrough, block, rephrase
if action == "block":
raise HTTPException(
status_code=400,
detail={
"error": "Violated guardrail policy",
"aporio_ai_response": _json_response,
},
)

View file

@ -90,6 +90,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
async def async_post_call_success_hook(
    self,
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
    response,
):

View file

@ -0,0 +1,32 @@
from typing import Literal
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.guardrails import GuardrailEventHooks
class CustomGuardrail(CustomLogger):
def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs):
self.guardrail_name = guardrail_name
self.event_hook: GuardrailEventHooks = event_hook
super().__init__(**kwargs)
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
verbose_logger.debug(
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s",
self.guardrail_name,
event_type,
self.event_hook,
)
metadata = data.get("metadata") or {}
requested_guardrails = metadata.get("guardrails") or []
if self.guardrail_name not in requested_guardrails:
return False
if self.event_hook != event_type:
return False
return True
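For integration authors, a minimal sketch of how this class is meant to be subclassed, modeled on the Aporia hook elsewhere in this PR. The `MyGuardrail` class and its blocking rule are purely illustrative, not part of the PR:

```python
from fastapi import HTTPException

import litellm
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import GuardrailEventHooks


class MyGuardrail(CustomGuardrail):
    # hypothetical guardrail: reject any prompt containing the word "forbidden"
    async def async_moderation_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, call_type: str
    ):
        # only run if the request asked for this guardrail AND the configured mode matches
        if not self.should_run_guardrail(
            data=data, event_type=GuardrailEventHooks.during_call
        ):
            return
        for m in data.get("messages", []):
            if "forbidden" in str(m.get("content", "")):
                raise HTTPException(
                    status_code=400, detail={"error": "Violated guardrail policy"}
                )


# registered the same way init_guardrails_v2 registers the Aporia hook
my_guard = MyGuardrail(guardrail_name="my-guard", event_hook=GuardrailEventHooks.during_call)
litellm.callbacks.append(my_guard)
```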

View file

@ -122,6 +122,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
async def async_post_call_success_hook(
    self,
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
    response,
):

View file

@ -1,4 +1,12 @@
- from typing import Any
+ from typing import TYPE_CHECKING, Any, Optional, Union

+ if TYPE_CHECKING:
+     from litellm import ModelResponse as _ModelResponse
+
+     LiteLLMModelResponse = _ModelResponse
+ else:
+     LiteLLMModelResponse = Any

import litellm
@ -20,3 +28,21 @@ def convert_litellm_response_object_to_dict(response_obj: Any) -> dict:
    # If it's not a LiteLLM type, return the object as is
    return dict(response_obj)
def convert_litellm_response_object_to_str(
response_obj: Union[Any, LiteLLMModelResponse]
) -> Optional[str]:
"""
Get the string of the response object from LiteLLM
"""
if isinstance(response_obj, litellm.ModelResponse):
response_str = ""
for choice in response_obj.choices:
if isinstance(choice, litellm.Choices):
if choice.message.content and isinstance(choice.message.content, str):
response_str += choice.message.content
return response_str
return None
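A quick usage sketch for the helper above; `mock_response` keeps it offline, and nothing here is new API beyond what the PR adds:

```python
import litellm
from litellm.litellm_core_utils.logging_utils import convert_litellm_response_object_to_str

# mock_response avoids a real provider call and still returns a ModelResponse
resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="Hello there!",
)

# For a ModelResponse this concatenates every choice's message.content;
# for anything else (e.g. an EmbeddingResponse) it returns None.
print(convert_litellm_response_object_to_str(resp))  # -> "Hello there!"
```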

View file

@ -118,17 +118,19 @@ def initialize_callbacks_on_proxy(
                **init_params
            )
            imported_list.append(lakera_moderations_object)
-       elif isinstance(callback, str) and callback == "aporio_prompt_injection":
+       elif isinstance(callback, str) and callback == "aporia_prompt_injection":
-           from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio
+           from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
+               _ENTERPRISE_Aporia,
+           )

            if premium_user is not True:
                raise Exception(
-                   "Trying to use Aporio AI Guardrail"
+                   "Trying to use Aporia AI Guardrail"
                    + CommonProxyErrors.not_premium_user.value
                )

-           aporio_guardrail_object = _ENTERPRISE_Aporio()
+           aporia_guardrail_object = _ENTERPRISE_Aporia()
-           imported_list.append(aporio_guardrail_object)
+           imported_list.append(aporia_guardrail_object)
        elif isinstance(callback, str) and callback == "google_text_moderation":
            from enterprise.enterprise_hooks.google_text_moderation import (
                _ENTERPRISE_GoogleTextModeration,
@ -295,3 +297,21 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str,
        headers[f"x-litellm-key-remaining-tokens-{model_group}"] = remaining_tokens
    return headers
def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]:
_metadata = request_data.get("metadata", None) or {}
if "applied_guardrails" in _metadata:
return {
"x-litellm-applied-guardrails": ",".join(_metadata["applied_guardrails"]),
}
return None
def add_guardrail_to_applied_guardrails_header(request_data: Dict, guardrail_name: str):
_metadata = request_data.get("metadata", None) or {}
if "applied_guardrails" in _metadata:
_metadata["applied_guardrails"].append(guardrail_name)
else:
_metadata["applied_guardrails"] = [guardrail_name]
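A rough illustration of the metadata flow these two helpers implement (the dict literal is made up for the example):

```python
from litellm.proxy.common_utils.callback_utils import (
    add_guardrail_to_applied_guardrails_header,
    get_applied_guardrails_header,
)

# metadata must already be non-empty here: the helpers do
# `request_data.get("metadata", None) or {}`, so an empty dict would be
# swapped for a fresh one and the update would be lost
request_data = {"metadata": {"user_api_key_alias": "my-key"}}

# each guardrail that actually ran appends its name to request metadata
add_guardrail_to_applied_guardrails_header(
    request_data=request_data, guardrail_name="aporia-pre-guard"
)
add_guardrail_to_applied_guardrails_header(
    request_data=request_data, guardrail_name="aporia-post-guard"
)

# the proxy later turns that list into a response header
print(get_applied_guardrails_header(request_data))
# {'x-litellm-applied-guardrails': 'aporia-pre-guard,aporia-post-guard'}
```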

View file

@ -40,6 +40,7 @@ class MyCustomHandler(
async def async_post_call_success_hook(
    self,
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
    response,
):

View file

@ -9,3 +9,16 @@ litellm_settings:
cache: true
callbacks: ["otel"]
guardrails:
- guardrail_name: "aporia-pre-guard"
litellm_params:
guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
mode: "post_call"
api_key: os.environ/APORIA_API_KEY_1
api_base: os.environ/APORIA_API_BASE_1
- guardrail_name: "aporia-post-guard"
litellm_params:
guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
mode: "post_call"
api_key: os.environ/APORIA_API_KEY_2
api_base: os.environ/APORIA_API_BASE_2

View file

@ -37,32 +37,35 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b
    requested_callback_names = []

-   # get guardrail configs from `init_guardrails.py`
-   # for all requested guardrails -> get their associated callbacks
-   for _guardrail_name, should_run in request_guardrails.items():
-       if should_run is False:
-           verbose_proxy_logger.debug(
-               "Guardrail %s skipped because request set to False",
-               _guardrail_name,
-           )
-           continue
-       # lookup the guardrail in guardrail_name_config_map
-       guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
-           _guardrail_name
-       ]
-       guardrail_callbacks = guardrail_item.callbacks
-       requested_callback_names.extend(guardrail_callbacks)
-   verbose_proxy_logger.debug(
-       "requested_callback_names %s", requested_callback_names
-   )
-   if guardrail_name in requested_callback_names:
-       return True
-   # Do no proceeed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } }
-   return False

+   # v1 implementation of this
+   if isinstance(request_guardrails, dict):
+
+       # get guardrail configs from `init_guardrails.py`
+       # for all requested guardrails -> get their associated callbacks
+       for _guardrail_name, should_run in request_guardrails.items():
+           if should_run is False:
+               verbose_proxy_logger.debug(
+                   "Guardrail %s skipped because request set to False",
+                   _guardrail_name,
+               )
+               continue
+
+           # lookup the guardrail in guardrail_name_config_map
+           guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
+               _guardrail_name
+           ]
+
+           guardrail_callbacks = guardrail_item.callbacks
+           requested_callback_names.extend(guardrail_callbacks)
+
+       verbose_proxy_logger.debug(
+           "requested_callback_names %s", requested_callback_names
+       )
+       if guardrail_name in requested_callback_names:
+           return True
+
+       # Do no proceeed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } }
+       return False

    return True

View file

@ -0,0 +1,212 @@
# +-------------------------------------------------------------+
#
# Use AporiaAI for your LLM calls
#
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import os
import sys
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import json
import sys
import traceback
import uuid
from datetime import datetime
from typing import Any, List, Literal, Optional, Union
import aiohttp
import httpx
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.logging_utils import (
convert_litellm_response_object_to_str,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks
litellm.set_verbose = True
GUARDRAIL_NAME = "aporia"
class _ENTERPRISE_Aporia(CustomGuardrail):
def __init__(
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
self.event_hook: GuardrailEventHooks
super().__init__(**kwargs)
#### CALL HOOKS - proxy only ####
def transform_messages(self, messages: List[dict]) -> List[dict]:
supported_openai_roles = ["system", "user", "assistant"]
default_role = "other" # for unsupported roles - e.g. tool
new_messages = []
for m in messages:
if m.get("role", "") in supported_openai_roles:
new_messages.append(m)
else:
new_messages.append(
{
"role": default_role,
**{key: value for key, value in m.items() if key != "role"},
}
)
return new_messages
async def prepare_aporia_request(
self, new_messages: List[dict], response_string: Optional[str] = None
) -> dict:
data: dict[str, Any] = {}
if new_messages is not None:
data["messages"] = new_messages
if response_string is not None:
data["response"] = response_string
# Set validation target
if new_messages and response_string:
data["validation_target"] = "both"
elif new_messages:
data["validation_target"] = "prompt"
elif response_string:
data["validation_target"] = "response"
verbose_proxy_logger.debug("Aporia AI request: %s", data)
return data
async def make_aporia_api_request(
self, new_messages: List[dict], response_string: Optional[str] = None
):
data = await self.prepare_aporia_request(
new_messages=new_messages, response_string=response_string
)
_json_data = json.dumps(data)
"""
export APORIO_API_KEY=<your key>
curl https://gr-prd-trial.aporia.com/some-id \
-X POST \
-H "X-APORIA-API-KEY: $APORIO_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{
"role": "user",
"content": "This is a test prompt"
}
],
}
'
"""
response = await self.async_handler.post(
url=self.aporia_api_base + "/validate",
data=_json_data,
headers={
"X-APORIA-API-KEY": self.aporia_api_key,
"Content-Type": "application/json",
},
)
verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
if response.status_code == 200:
# check if the response was flagged
_json_response = response.json()
action: str = _json_response.get(
"action"
) # possible values are modify, passthrough, block, rephrase
if action == "block":
raise HTTPException(
status_code=400,
detail={
"error": "Violated guardrail policy",
"aporia_ai_response": _json_response,
},
)
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
"""
Use this for the post call moderation with Guardrails
"""
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
return
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
if response_str is not None:
await self.make_aporia_api_request(
response_string=response_str, new_messages=data.get("messages", [])
)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
pass
async def async_moderation_hook( ### 👈 KEY CHANGE ###
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
return
# old implementation - backwards compatibility
if (
await should_proceed_based_on_metadata(
data=data,
guardrail_name=GUARDRAIL_NAME,
)
is False
):
return
new_messages: Optional[List[dict]] = None
if "messages" in data and isinstance(data["messages"], list):
new_messages = self.transform_messages(messages=data["messages"])
if new_messages is not None:
await self.make_aporia_api_request(new_messages=new_messages)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
else:
verbose_proxy_logger.warning(
"Aporia AI: not running guardrail. No messages in data"
)
pass
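For orientation, this is roughly how the class above gets wired into the proxy (see `init_guardrails_v2` further down). The values are placeholders, and the import path is the one `callback_utils.py` uses in this PR, so treat it as an assumption about where this module lives:

```python
import litellm
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import _ENTERPRISE_Aporia
from litellm.types.guardrails import GuardrailEventHooks

aporia_post_guard = _ENTERPRISE_Aporia(
    api_key="<APORIA_API_KEY_2>",    # placeholder; normally resolved from os.environ/...
    api_base="<APORIA_API_BASE_2>",  # placeholder
    guardrail_name="aporia-post-guard",
    event_hook=GuardrailEventHooks.post_call,
)

# once appended to litellm.callbacks, the proxy invokes its hooks on every request
litellm.callbacks.append(aporia_post_guard)
```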

View file

@ -1,12 +1,20 @@
import traceback
- from typing import Dict, List
+ from typing import Dict, List, Literal

from pydantic import BaseModel, RootModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
- from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec

+ # v2 implementation
+ from litellm.types.guardrails import (
+     Guardrail,
+     GuardrailItem,
+     GuardrailItemSpec,
+     LitellmParams,
+     guardrailConfig,
+ )

all_guardrails: List[GuardrailItem] = []
@ -66,3 +74,68 @@ def initialize_guardrails(
            "error initializing guardrails {}".format(str(e))
        )
        raise e
"""
Map guardrail_name: <pre_call>, <post_call>, during_call
"""
def init_guardrails_v2(all_guardrails: dict):
# Convert the loaded data to the TypedDict structure
guardrail_list = []
# Parse each guardrail and replace environment variables
for guardrail in all_guardrails:
# Init litellm params for guardrail
litellm_params_data = guardrail["litellm_params"]
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
litellm_params = LitellmParams(
guardrail=litellm_params_data["guardrail"],
mode=litellm_params_data["mode"],
api_key=litellm_params_data["api_key"],
api_base=litellm_params_data["api_base"],
)
if litellm_params["api_key"]:
if litellm_params["api_key"].startswith("os.environ/"):
litellm_params["api_key"] = litellm.get_secret(
litellm_params["api_key"]
)
if litellm_params["api_base"]:
if litellm_params["api_base"].startswith("os.environ/"):
litellm_params["api_base"] = litellm.get_secret(
litellm_params["api_base"]
)
# Init guardrail CustomLoggerClass
if litellm_params["guardrail"] == "aporia":
from guardrail_hooks.aporia_ai import _ENTERPRISE_Aporia
_aporia_callback = _ENTERPRISE_Aporia(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
)
litellm.callbacks.append(_aporia_callback) # type: ignore
elif litellm_params["guardrail"] == "lakera":
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
_ENTERPRISE_lakeraAI_Moderation,
)
_lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
litellm.callbacks.append(_lakera_callback) # type: ignore
parsed_guardrail = Guardrail(
guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params
)
guardrail_list.append(parsed_guardrail)
guardrail_name = guardrail["guardrail_name"]
# pretty print guardrail_list in green
print(f"\nGuardrail List:{guardrail_list}\n") # noqa
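A sketch of the parsed YAML structure this function expects, matching the `guardrails:` blocks in the config files elsewhere in this PR; the import path is the one `proxy_server.py` uses:

```python
from litellm.proxy.guardrails.init_guardrails import init_guardrails_v2

all_guardrails = [
    {
        "guardrail_name": "aporia-pre-guard",
        "litellm_params": {
            "guardrail": "aporia",
            "mode": "during_call",
            "api_key": "os.environ/APORIA_API_KEY_1",   # resolved via litellm.get_secret()
            "api_base": "os.environ/APORIA_API_BASE_1",
        },
    },
]

# builds the guardrail objects and appends them to litellm.callbacks
init_guardrails_v2(all_guardrails=all_guardrails)
```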

View file

@ -1,11 +1,16 @@
- from litellm.integrations.custom_logger import CustomLogger
- from litellm.caching import DualCache
- from litellm.proxy._types import UserAPIKeyAuth
- import litellm, traceback, sys, uuid
- from fastapi import HTTPException
- from litellm._logging import verbose_proxy_logger
+ import sys
+ import traceback
+ import uuid
from typing import Optional

+ from fastapi import HTTPException

+ import litellm
+ from litellm._logging import verbose_proxy_logger
+ from litellm.caching import DualCache
+ from litellm.integrations.custom_logger import CustomLogger
+ from litellm.proxy._types import UserAPIKeyAuth

class _PROXY_AzureContentSafety(
    CustomLogger
@ -15,12 +20,12 @@ class _PROXY_AzureContentSafety(
    def __init__(self, endpoint, api_key, thresholds=None):
        try:
            from azure.ai.contentsafety.aio import ContentSafetyClient
-           from azure.core.credentials import AzureKeyCredential
            from azure.ai.contentsafety.models import (
-               TextCategory,
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
+               TextCategory,
            )
+           from azure.core.credentials import AzureKeyCredential
            from azure.core.exceptions import HttpResponseError
        except Exception as e:
            raise Exception(
@ -132,6 +137,7 @@ class _PROXY_AzureContentSafety(
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):

View file

@ -254,7 +254,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
        return None

    async def async_post_call_success_hook(
-       self, user_api_key_dict: UserAPIKeyAuth, response
+       self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
        try:
            if isinstance(response, ModelResponse):
@ -287,7 +287,9 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
                return response
            return await super().async_post_call_success_hook(
-               user_api_key_dict, response
+               data=data,
+               user_api_key_dict=user_api_key_dict,
+               response=response,
            )
        except Exception as e:
            verbose_proxy_logger.exception(

View file

@ -322,6 +322,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
async def async_post_call_success_hook(
    self,
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
    response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
):

View file

@ -316,9 +316,20 @@ async def add_litellm_data_to_request(
    for k, v in callback_settings_obj.callback_vars.items():
        data[k] = v
# Guardrails
move_guardrails_to_metadata(
data=data, _metadata_variable_name=_metadata_variable_name
)
    return data
def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str):
if "guardrails" in data:
data[_metadata_variable_name]["guardrails"] = data["guardrails"]
del data["guardrails"]
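A small before/after sketch of what this helper does to the request body, assuming the function above is in scope; the dict literal is illustrative only:

```python
data = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hi"}],
    "guardrails": ["aporia-pre-guard"],  # top-level field sent by the client
    "metadata": {},
}

move_guardrails_to_metadata(data=data, _metadata_variable_name="metadata")

# "guardrails" now lives under metadata, which is where
# CustomGuardrail.should_run_guardrail() looks it up
assert data["metadata"]["guardrails"] == ["aporia-pre-guard"]
assert "guardrails" not in data
```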
def add_provider_specific_headers_to_request(
    data: dict,
    headers: dict,

View file

@ -1,50 +1,20 @@
model_list:
- - model_name: gpt-4
+ - model_name: fake-openai-endpoint
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-   model_info:
-     access_groups: ["beta-models"]
- - model_name: fireworks-llama-v3-70b-instruct
-   litellm_params:
-     model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
-     api_key: "os.environ/FIREWORKS"
-   model_info:
-     access_groups: ["beta-models"]
- - model_name: "*"
-   litellm_params:
-     model: "*"
- - model_name: "*"
-   litellm_params:
-     model: openai/*
-     api_key: os.environ/OPENAI_API_KEY
- - model_name: mistral-small-latest
-   litellm_params:
-     model: mistral/mistral-small-latest
-     api_key: "os.environ/MISTRAL_API_KEY"
- - model_name: bedrock-anthropic
-   litellm_params:
-     model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
- - model_name: gemini-1.5-pro-001
-   litellm_params:
-     model: vertex_ai_beta/gemini-1.5-pro-001
-     vertex_project: "adroit-crow-413218"
-     vertex_location: "us-central1"
-     vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json"
-     # Add path to service account.json

- default_vertex_config:
-   vertex_project: "adroit-crow-413218"
-   vertex_location: "us-central1"
-   vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json

- general_settings:
-   master_key: sk-1234
-   alerting: ["slack"]

- litellm_settings:
-   fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
-   success_callback: ["langfuse", "prometheus"]
-   langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]

+ guardrails:
+   - guardrail_name: "aporia-pre-guard"
+     litellm_params:
+       guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
+       mode: "post_call"
+       api_key: os.environ/APORIA_API_KEY_1
+       api_base: os.environ/APORIA_API_BASE_1
+   - guardrail_name: "aporia-post-guard"
+     litellm_params:
+       guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
+       mode: "post_call"
+       api_key: os.environ/APORIA_API_KEY_2
+       api_base: os.environ/APORIA_API_BASE_2

View file

@ -149,6 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
    show_missing_vars_in_env,
)
from litellm.proxy.common_utils.callback_utils import (
    get_applied_guardrails_header,
    get_remaining_tokens_and_requests_from_request_data,
    initialize_callbacks_on_proxy,
)
@ -168,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
)
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
- from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
+ from litellm.proxy.guardrails.init_guardrails import (
+     init_guardrails_v2,
+     initialize_guardrails,
+ )
from litellm.proxy.health_check import perform_health_check
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.hooks.prompt_injection_detection import (
@ -539,6 +543,10 @@ def get_custom_headers(
        )
        headers.update(remaining_tokens_header)

    applied_guardrails = get_applied_guardrails_header(request_data)
    if applied_guardrails:
        headers.update(applied_guardrails)

    try:
        return {
            key: value for key, value in headers.items() if value not in exclude_values
@ -1937,6 +1945,11 @@ class ProxyConfig:
                    async_only_mode=True # only init async clients
                ),
            ) # type:ignore

        # Guardrail settings
        guardrails_v2 = config.get("guardrails", None)
        if guardrails_v2:
            init_guardrails_v2(all_guardrails=guardrails_v2)

        return router, router.get_model_list(), general_settings

    def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
@ -3139,7 +3152,7 @@ async def chat_completion(
        ### CALL HOOKS ### - modify outgoing data
        response = await proxy_logging_obj.post_call_success_hook(
-           user_api_key_dict=user_api_key_dict, response=response
+           data=data, user_api_key_dict=user_api_key_dict, response=response
        )

        hidden_params = (
@ -3353,6 +3366,11 @@ async def completion(
                media_type="text/event-stream",
                headers=custom_headers,
            )

        ### CALL HOOKS ### - modify outgoing data
        response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response
        )

        fastapi_response.headers.update(
            get_custom_headers(
                user_api_key_dict=user_api_key_dict,

View file

@ -432,12 +432,11 @@ class ProxyLogging:
        """
        Runs the CustomLogger's async_moderation_hook()
        """
-       new_data = safe_deep_copy(data)
        for callback in litellm.callbacks:
            try:
                if isinstance(callback, CustomLogger):
                    await callback.async_moderation_hook(
-                       data=new_data,
+                       data=data,
                        user_api_key_dict=user_api_key_dict,
                        call_type=call_type,
                    )
@ -717,6 +716,7 @@ class ProxyLogging:
    async def post_call_success_hook(
        self,
        data: dict,
        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
        user_api_key_dict: UserAPIKeyAuth,
    ):
@ -738,7 +738,9 @@ class ProxyLogging:
                _callback = callback # type: ignore
                if _callback is not None and isinstance(_callback, CustomLogger):
                    await _callback.async_post_call_success_hook(
-                       user_api_key_dict=user_api_key_dict, response=response
+                       user_api_key_dict=user_api_key_dict,
+                       data=data,
+                       response=response,
                    )
            except Exception as e:
                raise e

View file

@ -1,8 +1,13 @@
# What is this?
## Unit test for azure content safety
- import sys, os, asyncio, time, random
- from datetime import datetime
+ import asyncio
+ import os
+ import random
+ import sys
+ import time
import traceback
+ from datetime import datetime

from dotenv import load_dotenv
from fastapi import HTTPException
@ -13,11 +18,12 @@ sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest

import litellm
from litellm import Router, mock_completion
- from litellm.proxy.utils import ProxyLogging
- from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
+ from litellm.proxy._types import UserAPIKeyAuth
+ from litellm.proxy.utils import ProxyLogging

@pytest.mark.asyncio
@ -177,7 +183,13 @@ async def test_strict_output_filtering_01():
    with pytest.raises(HTTPException) as exc_info:
        await azure_content_safety.async_post_call_success_hook(
-           user_api_key_dict=UserAPIKeyAuth(), response=response
+           user_api_key_dict=UserAPIKeyAuth(),
+           data={
+               "messages": [
+                   {"role": "system", "content": "You are an helpfull assistant"}
+               ]
+           },
+           response=response,
        )

    assert exc_info.value.detail["source"] == "output"
@ -216,7 +228,11 @@ async def test_strict_output_filtering_02():
    )

    await azure_content_safety.async_post_call_success_hook(
-       user_api_key_dict=UserAPIKeyAuth(), response=response
+       user_api_key_dict=UserAPIKeyAuth(),
+       data={
+           "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+       },
+       response=response,
    )
@ -251,7 +267,11 @@ async def test_loose_output_filtering_01():
    )

    await azure_content_safety.async_post_call_success_hook(
-       user_api_key_dict=UserAPIKeyAuth(), response=response
+       user_api_key_dict=UserAPIKeyAuth(),
+       data={
+           "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+       },
+       response=response,
    )
@ -286,5 +306,9 @@ async def test_loose_output_filtering_02():
    )

    await azure_content_safety.async_post_call_success_hook(
-       user_api_key_dict=UserAPIKeyAuth(), response=response
+       user_api_key_dict=UserAPIKeyAuth(),
+       data={
+           "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+       },
+       response=response,
    )

View file

@ -88,7 +88,11 @@ async def test_output_parsing():
        mock_response="Hello <PERSON>! How can I assist you today?",
    )

    new_response = await pii_masking.async_post_call_success_hook(
-       user_api_key_dict=UserAPIKeyAuth(), response=response
+       user_api_key_dict=UserAPIKeyAuth(),
+       data={
+           "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+       },
+       response=response,
    )

    assert (

View file

@ -1,5 +1,5 @@
from enum import Enum
- from typing import Dict, List, Optional
+ from typing import Dict, List, Optional, TypedDict

from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict
@ -63,3 +63,26 @@ class GuardrailItem(BaseModel):
            enabled_roles=enabled_roles,
            callback_args=callback_args,
        )
# Define the TypedDicts
class LitellmParams(TypedDict):
guardrail: str
mode: str
api_key: str
api_base: Optional[str]
class Guardrail(TypedDict):
guardrail_name: str
litellm_params: LitellmParams
class guardrailConfig(TypedDict):
guardrails: List[Guardrail]
class GuardrailEventHooks(str, Enum):
pre_call = "pre_call"
post_call = "post_call"
during_call = "during_call"
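To make the shape concrete, a small sketch of building one of these entries by hand; TypedDicts are plain dicts at runtime, and the values below are placeholders:

```python
from litellm.types.guardrails import Guardrail, GuardrailEventHooks, LitellmParams

aporia_pre: Guardrail = {
    "guardrail_name": "aporia-pre-guard",
    "litellm_params": LitellmParams(
        guardrail="aporia",
        mode=GuardrailEventHooks.during_call.value,  # "during_call"
        api_key="os.environ/APORIA_API_KEY_1",       # resolved by init_guardrails_v2
        api_base="os.environ/APORIA_API_BASE_1",
    ),
}
```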

View file

@ -0,0 +1,118 @@
import pytest
import asyncio
import aiohttp, openai
from openai import OpenAI, AsyncOpenAI
from typing import Optional, List, Union
import uuid
async def chat_completion(
session,
key,
messages,
model: Union[str, List] = "gpt-4",
guardrails: Optional[List] = None,
):
url = "http://0.0.0.0:4000/chat/completions"
headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
}
data = {
"model": model,
"messages": messages,
"guardrails": [
"aporia-post-guard",
"aporia-pre-guard",
], # default guardrails for all tests
}
if guardrails is not None:
data["guardrails"] = guardrails
print("data=", data)
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(response_text)
print()
if status != 200:
return response_text
# response headers
response_headers = response.headers
print("response headers=", response_headers)
return await response.json(), response_headers
@pytest.mark.asyncio
async def test_llm_guard_triggered_safe_request():
"""
- Tests a request where no content mod is triggered
- Assert that the guardrails applied are returned in the response headers
"""
async with aiohttp.ClientSession() as session:
response, headers = await chat_completion(
session,
"sk-1234",
model="fake-openai-endpoint",
messages=[{"role": "user", "content": f"Hello what's the weather"}],
)
await asyncio.sleep(3)
print("response=", response, "response headers", headers)
assert "x-litellm-applied-guardrails" in headers
assert (
headers["x-litellm-applied-guardrails"]
== "aporia-pre-guard,aporia-post-guard"
)
@pytest.mark.asyncio
async def test_llm_guard_triggered():
    """
    - Tests a request where the Aporia pre-call guardrail is triggered (the prompt contains PII)
    - Assert that the request is blocked and the Aporia error is surfaced
    """
async with aiohttp.ClientSession() as session:
try:
response, headers = await chat_completion(
session,
"sk-1234",
model="fake-openai-endpoint",
messages=[
{"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
],
)
pytest.fail("Should have thrown an exception")
except Exception as e:
print(e)
assert "Aporia detected and blocked PII" in str(e)
@pytest.mark.asyncio
async def test_no_llm_guard_triggered():
"""
- Tests a request where no content mod is triggered
    - Assert that no `x-litellm-applied-guardrails` header is returned, since the request sets `guardrails: []`
"""
async with aiohttp.ClientSession() as session:
response, headers = await chat_completion(
session,
"sk-1234",
model="fake-openai-endpoint",
messages=[{"role": "user", "content": f"Hello what's the weather"}],
guardrails=[],
)
await asyncio.sleep(3)
print("response=", response, "response headers", headers)
assert "x-litellm-applied-guardrails" not in headers