Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

Merge pull request #5288 from BerriAI/litellm_aporia_refactor
[Feat] V2 aporia guardrails litellm

Commit c82714757a: 33 changed files with 1,078 additions and 337 deletions
@@ -317,6 +317,10 @@ jobs:
           -e OPENAI_API_KEY=$OPENAI_API_KEY \
           -e LITELLM_LICENSE=$LITELLM_LICENSE \
           -e OTEL_EXPORTER="in_memory" \
+          -e APORIA_API_BASE_2=$APORIA_API_BASE_2 \
+          -e APORIA_API_KEY_2=$APORIA_API_KEY_2 \
+          -e APORIA_API_BASE_1=$APORIA_API_BASE_1 \
+          -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
           --name my-app \
           -v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
           my-app:latest \
@@ -47,6 +47,7 @@ class MyCustomHandler(CustomLogger):  # https://docs.litellm.ai/docs/observability
     async def async_post_call_success_hook(
         self,
+        data: dict,
         user_api_key_dict: UserAPIKeyAuth,
         response,
     ):
@@ -36,7 +36,7 @@ Features:
 - **Guardrails, PII Masking, Content Moderation**
     - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
     - ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
-    - ✅ [Prompt Injection Detection (with Aporio API)](#prompt-injection-detection---aporio-ai)
+    - ✅ [Prompt Injection Detection (with Aporia API)](#prompt-injection-detection---aporia-ai)
     - ✅ [Switch LakeraAI on / off per request](guardrails#control-guardrails-onoff-per-request)
     - ✅ Reject calls from Blocked User list
     - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
@@ -1035,9 +1035,9 @@ curl --location 'http://localhost:4000/chat/completions' \
 Need to control LakeraAI per Request ? Doc here 👉: [Switch LakerAI on / off per request](prompt_injection.md#✨-enterprise-switch-lakeraai-on--off-per-api-call)
 :::
 
-## Prompt Injection Detection - Aporio AI
+## Prompt Injection Detection - Aporia AI
 
-Use this if you want to reject /chat/completion calls that have prompt injection attacks with [AporioAI](https://www.aporia.com/)
+Use this if you want to reject /chat/completion calls that have prompt injection attacks with [AporiaAI](https://www.aporia.com/)
 
 #### Usage
 
@@ -1048,11 +1048,11 @@ APORIO_API_KEY="eyJh****"
 APORIO_API_BASE="https://gr..."
 ```
 
-Step 2. Add `aporio_prompt_injection` to your callbacks
+Step 2. Add `aporia_prompt_injection` to your callbacks
 
 ```yaml
 litellm_settings:
-  callbacks: ["aporio_prompt_injection"]
+  callbacks: ["aporia_prompt_injection"]
 ```
 
 That's it, start your proxy
@@ -1081,7 +1081,7 @@ curl --location 'http://localhost:4000/chat/completions' \
     "error": {
       "message": {
         "error": "Violated guardrail policy",
-        "aporio_ai_response": {
+        "aporia_ai_response": {
           "action": "block",
           "revised_prompt": null,
           "revised_response": "Profanity detected: Message blocked because it includes profanity. Please rephrase.",
@@ -1097,7 +1097,7 @@ curl --location 'http://localhost:4000/chat/completions' \
 
 :::info
 
-Need to control AporioAI per Request ? Doc here 👉: [Create a guardrail](./guardrails.md)
+Need to control AporiaAI per Request ? Doc here 👉: [Create a guardrail](./guardrails.md)
 :::
 
 
@@ -1,18 +1,10 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
 
-# 🛡️ Guardrails
+# 🛡️ [Beta] Guardrails
 
 Setup Prompt Injection Detection, Secret Detection on LiteLLM Proxy
 
-:::info
-
-✨ Enterprise Only Feature
-
-Schedule a meeting with us to get an Enterprise License 👉 Talk to founders [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
-
-:::
-
 ## Quick Start
 
 ### 1. Setup guardrails on litellm proxy config.yaml
@@ -1,98 +0,0 @@
import Image from '@theme/IdealImage';

# Split traffic between GPT-4 and Llama2 in Production!
In this tutorial, we'll walk through A/B testing between GPT-4 and Llama2 in production. We'll assume you've deployed Llama2 on Huggingface Inference Endpoints (but any of TogetherAI, Baseten, Ollama, Petals, Openrouter should work as well).

# Relevant Resources:

* 🚀 [Your production dashboard!](https://admin.litellm.ai/)

* [Deploying models on Huggingface](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint)
* [All supported providers on LiteLLM](https://docs.litellm.ai/docs/providers)

# Code Walkthrough

In production, we don't know if Llama2 is going to provide:
* good results
* quickly

### 💡 Route 20% traffic to Llama2
If Llama2 returns poor answers / is extremely slow, we want to roll-back this change, and use GPT-4 instead.

Instead of routing 100% of our traffic to Llama2, let's **start by routing 20% traffic** to it and see how it does.

```python
## route 20% of responses to Llama2
split_per_model = {
    "gpt-4": 0.8,
    "huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2
}
```

## 👨‍💻 Complete Code

### a) For Local
If we're testing this in a script - this is what our complete code looks like.
```python
from litellm import completion_with_split_tests
import os

## set ENV variables
os.environ["OPENAI_API_KEY"] = "openai key"
os.environ["HUGGINGFACE_API_KEY"] = "huggingface key"

## route 20% of responses to Llama2
split_per_model = {
    "gpt-4": 0.8,
    "huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2
}

messages = [{ "content": "Hello, how are you?","role": "user"}]

completion_with_split_tests(
    models=split_per_model,
    messages=messages,
)
```

### b) For Production

If we're in production, we don't want to keep going to code to change model/test details (prompt, split%, etc.) for our completion function and redeploying changes.

LiteLLM exposes a client dashboard to do this in a UI - and instantly updates our completion function in prod.

#### Relevant Code

```python
completion_with_split_tests(..., use_client=True, id="my-unique-id")
```

#### Complete Code

```python
from litellm import completion_with_split_tests
import os

## set ENV variables
os.environ["OPENAI_API_KEY"] = "openai key"
os.environ["HUGGINGFACE_API_KEY"] = "huggingface key"

## route 20% of responses to Llama2
split_per_model = {
    "gpt-4": 0.8,
    "huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2
}

messages = [{ "content": "Hello, how are you?","role": "user"}]

completion_with_split_tests(
    models=split_per_model,
    messages=messages,
    use_client=True,
    id="my-unique-id"  # Auto-create this @ https://admin.litellm.ai/
)
```
docs/my-website/docs/tutorials/litellm_proxy_aporia.md (new file, 196 lines)
@@ -0,0 +1,196 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Use LiteLLM AI Gateway with Aporia Guardrails

In this tutorial we will use LiteLLM Proxy with Aporia to detect PII in requests and profanity in responses

## 1. Setup guardrails on Aporia

### Create Aporia Projects

Create two projects on [Aporia](https://guardrails.aporia.com/)

1. Pre LLM API Call - Set all the policies you want to run on pre LLM API call
2. Post LLM API Call - Set all the policies you want to run post LLM API call

<Image img={require('../../img/aporia_projs.png')} />

### Pre-Call: Detect PII

Add the `PII - Prompt` to your Pre LLM API Call project

<Image img={require('../../img/aporia_pre.png')} />

### Post-Call: Detect Profanity in Responses

Add the `Toxicity - Response` to your Post LLM API Call project

<Image img={require('../../img/aporia_post.png')} />

## 2. Define Guardrails on your LiteLLM config.yaml

- Define your guardrails under the `guardrails` section and set `pre_call_guardrails` and `post_call_guardrails`
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: openai/gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY

guardrails:
  - guardrail_name: "aporia-pre-guard"
    litellm_params:
      guardrail: aporia  # supported values: "aporia", "lakera"
      mode: "during_call"
      api_key: os.environ/APORIA_API_KEY_1
      api_base: os.environ/APORIA_API_BASE_1
  - guardrail_name: "aporia-post-guard"
    litellm_params:
      guardrail: aporia  # supported values: "aporia", "lakera"
      mode: "post_call"
      api_key: os.environ/APORIA_API_KEY_2
      api_base: os.environ/APORIA_API_BASE_2
```

### Supported values for `mode`

- `pre_call` Run **before** LLM call, on **input**
- `post_call` Run **after** LLM call, on **input & output**
- `during_call` Run **during** LLM call, on **input**
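
Under the hood, each `mode` maps to a hook on the guardrail's `CustomGuardrail` subclass: the Aporia integration runs `async_moderation_hook` for `during_call` and `async_post_call_success_hook` for `post_call`, gating both with `should_run_guardrail`. A minimal sketch of that wiring, assuming a hypothetical `MyGuardrail` class (not something shipped by LiteLLM):

```python
from typing import Literal

from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import GuardrailEventHooks


class MyGuardrail(CustomGuardrail):  # hypothetical example, for illustration only
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        # mode: "during_call" -> validate the incoming messages while the LLM call runs
        if not self.should_run_guardrail(data=data, event_type=GuardrailEventHooks.during_call):
            return
        # ... call your guardrail API on data["messages"] here ...

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        # mode: "post_call" -> validate the model output (and optionally the input) after the call
        if not self.should_run_guardrail(data=data, event_type=GuardrailEventHooks.post_call):
            return
        # ... call your guardrail API on the response text here ...
```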
## 3. Start LiteLLM Gateway

```shell
litellm --config config.yaml --detailed_debug
```

## 4. Test request

<Tabs>
<TabItem label="Unsuccessful call" value = "not-allowed">

Expect this to fail since `ishaan@berri.ai` in the request is PII

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
  -d '{
    "model": "gpt-3.5-turbo",
    "messages": [
      {"role": "user", "content": "hi my email is ishaan@berri.ai"}
    ],
    "guardrails": ["aporia-pre-guard", "aporia-post-guard"]
  }'
```

Expected response on failure

```shell
{
  "error": {
    "message": {
      "error": "Violated guardrail policy",
      "aporia_ai_response": {
        "action": "block",
        "revised_prompt": null,
        "revised_response": "Aporia detected and blocked PII",
        "explain_log": null
      }
    },
    "type": "None",
    "param": "None",
    "code": "400"
  }
}
```

</TabItem>

<TabItem label="Successful Call " value = "allowed">

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
  -d '{
    "model": "gpt-3.5-turbo",
    "messages": [
      {"role": "user", "content": "hi what is the weather"}
    ],
    "guardrails": ["aporia-pre-guard", "aporia-post-guard"]
  }'
```

</TabItem>

</Tabs>

## Advanced
### Control Guardrails per Project (API Key)

Use this to control what guardrails run per project. In this tutorial we only want the following guardrails to run for 1 project
- `pre_call_guardrails`: ["aporia-pre-guard"]
- `post_call_guardrails`: ["aporia-post-guard"]

**Step 1** Create Key with guardrail settings

<Tabs>
<TabItem value="/key/generate" label="/key/generate">

```shell
curl -X POST 'http://0.0.0.0:4000/key/generate' \
    -H 'Authorization: Bearer sk-1234' \
    -H 'Content-Type: application/json' \
    -d '{
        "pre_call_guardrails": ["aporia-pre-guard"],
        "post_call_guardrails": ["aporia"]
}'
```

</TabItem>
<TabItem value="/key/update" label="/key/update">

```shell
curl --location 'http://0.0.0.0:4000/key/update' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "key": "sk-jNm1Zar7XfNdZXp49Z1kSQ",
        "pre_call_guardrails": ["aporia"],
        "post_call_guardrails": ["aporia"]
}'
```

</TabItem>
</Tabs>

**Step 2** Test it with new key

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
    --header 'Authorization: Bearer sk-jNm1Zar7XfNdZXp49Z1kSQ' \
    --header 'Content-Type: application/json' \
    --data '{
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "my email is ishaan@berri.ai"
        }
    ]
}'
```
docs/my-website/img/aporia_post.png (new binary file, 250 KiB; not shown)
docs/my-website/img/aporia_pre.png (new binary file, 277 KiB; not shown)
docs/my-website/img/aporia_projs.png (new binary file, 153 KiB; not shown)
@@ -250,6 +250,7 @@ const sidebars = {
       type: "category",
       label: "Tutorials",
       items: [
+        'tutorials/litellm_proxy_aporia',
         'tutorials/azure_openai',
         'tutorials/instructor',
         "tutorials/gradio_integration",

enterprise/enterprise_hooks/aporia_ai.py (new file, 208 lines)
@@ -0,0 +1,208 @@
|
||||||
|
# +-------------------------------------------------------------+
|
||||||
|
#
|
||||||
|
# Use AporiaAI for your LLM calls
|
||||||
|
#
|
||||||
|
# +-------------------------------------------------------------+
|
||||||
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
from typing import Optional, Literal, Union, Any
|
||||||
|
import litellm, traceback, sys, uuid
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||||
|
from litellm.litellm_core_utils.logging_utils import (
|
||||||
|
convert_litellm_response_object_to_str,
|
||||||
|
)
|
||||||
|
from typing import List
|
||||||
|
from datetime import datetime
|
||||||
|
import aiohttp, asyncio
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
GUARDRAIL_NAME = "aporia"
|
||||||
|
|
||||||
|
|
||||||
|
class _ENTERPRISE_Aporia(CustomGuardrail):
|
||||||
|
def __init__(
|
||||||
|
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
|
self.async_handler = AsyncHTTPHandler(
|
||||||
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
|
)
|
||||||
|
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
|
||||||
|
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
|
||||||
|
self.event_hook: GuardrailEventHooks
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
#### CALL HOOKS - proxy only ####
|
||||||
|
def transform_messages(self, messages: List[dict]) -> List[dict]:
|
||||||
|
supported_openai_roles = ["system", "user", "assistant"]
|
||||||
|
default_role = "other" # for unsupported roles - e.g. tool
|
||||||
|
new_messages = []
|
||||||
|
for m in messages:
|
||||||
|
if m.get("role", "") in supported_openai_roles:
|
||||||
|
new_messages.append(m)
|
||||||
|
else:
|
||||||
|
new_messages.append(
|
||||||
|
{
|
||||||
|
"role": default_role,
|
||||||
|
**{key: value for key, value in m.items() if key != "role"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
async def prepare_aporia_request(
|
||||||
|
self, new_messages: List[dict], response_string: Optional[str] = None
|
||||||
|
) -> dict:
|
||||||
|
data: dict[str, Any] = {}
|
||||||
|
if new_messages is not None:
|
||||||
|
data["messages"] = new_messages
|
||||||
|
if response_string is not None:
|
||||||
|
data["response"] = response_string
|
||||||
|
|
||||||
|
# Set validation target
|
||||||
|
if new_messages and response_string:
|
||||||
|
data["validation_target"] = "both"
|
||||||
|
elif new_messages:
|
||||||
|
data["validation_target"] = "prompt"
|
||||||
|
elif response_string:
|
||||||
|
data["validation_target"] = "response"
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug("Aporia AI request: %s", data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def make_aporia_api_request(
|
||||||
|
self, new_messages: List[dict], response_string: Optional[str] = None
|
||||||
|
):
|
||||||
|
data = await self.prepare_aporia_request(
|
||||||
|
new_messages=new_messages, response_string=response_string
|
||||||
|
)
|
||||||
|
|
||||||
|
_json_data = json.dumps(data)
|
||||||
|
|
||||||
|
"""
|
||||||
|
export APORIO_API_KEY=<your key>
|
||||||
|
curl https://gr-prd-trial.aporia.com/some-id \
|
||||||
|
-X POST \
|
||||||
|
-H "X-APORIA-API-KEY: $APORIO_API_KEY" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "This is a test prompt"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
'
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = await self.async_handler.post(
|
||||||
|
url=self.aporia_api_base + "/validate",
|
||||||
|
data=_json_data,
|
||||||
|
headers={
|
||||||
|
"X-APORIA-API-KEY": self.aporia_api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
|
||||||
|
if response.status_code == 200:
|
||||||
|
# check if the response was flagged
|
||||||
|
_json_response = response.json()
|
||||||
|
action: str = _json_response.get(
|
||||||
|
"action"
|
||||||
|
) # possible values are modify, passthrough, block, rephrase
|
||||||
|
if action == "block":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "Violated guardrail policy",
|
||||||
|
"aporia_ai_response": _json_response,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_post_call_success_hook(
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
response,
|
||||||
|
):
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
add_guardrail_to_applied_guardrails_header,
|
||||||
|
)
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
"""
|
||||||
|
Use this for the post call moderation with Guardrails
|
||||||
|
"""
|
||||||
|
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
|
||||||
|
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||||
|
return
|
||||||
|
|
||||||
|
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
|
||||||
|
if response_str is not None:
|
||||||
|
await self.make_aporia_api_request(
|
||||||
|
response_string=response_str, new_messages=data.get("messages", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
add_guardrail_to_applied_guardrails_header(
|
||||||
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
|
)
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||||
|
):
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
add_guardrail_to_applied_guardrails_header,
|
||||||
|
)
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||||
|
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||||
|
return
|
||||||
|
|
||||||
|
# old implementation - backwards compatibility
|
||||||
|
if (
|
||||||
|
await should_proceed_based_on_metadata(
|
||||||
|
data=data,
|
||||||
|
guardrail_name=GUARDRAIL_NAME,
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
new_messages: Optional[List[dict]] = None
|
||||||
|
if "messages" in data and isinstance(data["messages"], list):
|
||||||
|
new_messages = self.transform_messages(messages=data["messages"])
|
||||||
|
|
||||||
|
if new_messages is not None:
|
||||||
|
await self.make_aporia_api_request(new_messages=new_messages)
|
||||||
|
add_guardrail_to_applied_guardrails_header(
|
||||||
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
"Aporia AI: not running guardrail. No messages in data"
|
||||||
|
)
|
||||||
|
pass
|
|
@ -1,124 +0,0 @@
|
||||||
# +-------------------------------------------------------------+
|
|
||||||
#
|
|
||||||
# Use AporioAI for your LLM calls
|
|
||||||
#
|
|
||||||
# +-------------------------------------------------------------+
|
|
||||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
|
||||||
|
|
||||||
import sys, os
|
|
||||||
|
|
||||||
sys.path.insert(
|
|
||||||
0, os.path.abspath("../..")
|
|
||||||
) # Adds the parent directory to the system path
|
|
||||||
from typing import Optional, Literal, Union
|
|
||||||
import litellm, traceback, sys, uuid
|
|
||||||
from litellm.caching import DualCache
|
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from litellm._logging import verbose_proxy_logger
|
|
||||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
|
||||||
from typing import List
|
|
||||||
from datetime import datetime
|
|
||||||
import aiohttp, asyncio
|
|
||||||
from litellm._logging import verbose_proxy_logger
|
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
|
||||||
import httpx
|
|
||||||
import json
|
|
||||||
|
|
||||||
litellm.set_verbose = True
|
|
||||||
|
|
||||||
GUARDRAIL_NAME = "aporio"
|
|
||||||
|
|
||||||
|
|
||||||
class _ENTERPRISE_Aporio(CustomLogger):
|
|
||||||
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
|
|
||||||
self.async_handler = AsyncHTTPHandler(
|
|
||||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
|
||||||
)
|
|
||||||
self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"]
|
|
||||||
self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"]
|
|
||||||
|
|
||||||
#### CALL HOOKS - proxy only ####
|
|
||||||
def transform_messages(self, messages: List[dict]) -> List[dict]:
|
|
||||||
supported_openai_roles = ["system", "user", "assistant"]
|
|
||||||
default_role = "other" # for unsupported roles - e.g. tool
|
|
||||||
new_messages = []
|
|
||||||
for m in messages:
|
|
||||||
if m.get("role", "") in supported_openai_roles:
|
|
||||||
new_messages.append(m)
|
|
||||||
else:
|
|
||||||
new_messages.append(
|
|
||||||
{
|
|
||||||
"role": default_role,
|
|
||||||
**{key: value for key, value in m.items() if key != "role"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_messages
|
|
||||||
|
|
||||||
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
|
||||||
self,
|
|
||||||
data: dict,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
|
||||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
|
||||||
):
|
|
||||||
|
|
||||||
if (
|
|
||||||
await should_proceed_based_on_metadata(
|
|
||||||
data=data,
|
|
||||||
guardrail_name=GUARDRAIL_NAME,
|
|
||||||
)
|
|
||||||
is False
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
new_messages: Optional[List[dict]] = None
|
|
||||||
if "messages" in data and isinstance(data["messages"], list):
|
|
||||||
new_messages = self.transform_messages(messages=data["messages"])
|
|
||||||
|
|
||||||
if new_messages is not None:
|
|
||||||
data = {"messages": new_messages, "validation_target": "prompt"}
|
|
||||||
|
|
||||||
_json_data = json.dumps(data)
|
|
||||||
|
|
||||||
"""
|
|
||||||
export APORIO_API_KEY=<your key>
|
|
||||||
curl https://gr-prd-trial.aporia.com/some-id \
|
|
||||||
-X POST \
|
|
||||||
-H "X-APORIA-API-KEY: $APORIO_API_KEY" \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "This is a test prompt"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
'
|
|
||||||
"""
|
|
||||||
|
|
||||||
response = await self.async_handler.post(
|
|
||||||
url=self.aporio_api_base + "/validate",
|
|
||||||
data=_json_data,
|
|
||||||
headers={
|
|
||||||
"X-APORIA-API-KEY": self.aporio_api_key,
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug("Aporio AI response: %s", response.text)
|
|
||||||
if response.status_code == 200:
|
|
||||||
# check if the response was flagged
|
|
||||||
_json_response = response.json()
|
|
||||||
action: str = _json_response.get(
|
|
||||||
"action"
|
|
||||||
) # possible values are modify, passthrough, block, rephrase
|
|
||||||
if action == "block":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail={
|
|
||||||
"error": "Violated guardrail policy",
|
|
||||||
"aporio_ai_response": _json_response,
|
|
||||||
},
|
|
||||||
)
|
|
|
@@ -90,6 +90,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
     async def async_post_call_success_hook(
         self,
+        data: dict,
         user_api_key_dict: UserAPIKeyAuth,
         response,
     ):

litellm/integrations/custom_guardrail.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from typing import Literal

from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.guardrails import GuardrailEventHooks


class CustomGuardrail(CustomLogger):

    def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs):
        self.guardrail_name = guardrail_name
        self.event_hook: GuardrailEventHooks = event_hook
        super().__init__(**kwargs)

    def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
        verbose_logger.debug(
            "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s",
            self.guardrail_name,
            event_type,
            self.event_hook,
        )

        metadata = data.get("metadata") or {}
        requested_guardrails = metadata.get("guardrails") or []

        if self.guardrail_name not in requested_guardrails:
            return False

        if self.event_hook != event_type:
            return False

        return True
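
`should_run_guardrail` only returns True when the request explicitly asked for this guardrail by name and the event matches the hook the guardrail was configured with. A small sketch of that behaviour (the request dict below is illustrative):

```python
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks

# a guardrail registered to run post-call
guard = CustomGuardrail(
    guardrail_name="aporia-post-guard", event_hook=GuardrailEventHooks.post_call
)

# request data as the proxy sees it once `guardrails` has been moved into metadata
data = {"metadata": {"guardrails": ["aporia-post-guard"]}}

print(guard.should_run_guardrail(data, GuardrailEventHooks.post_call))    # True
print(guard.should_run_guardrail(data, GuardrailEventHooks.during_call))  # False - wrong event
print(guard.should_run_guardrail({"metadata": {}}, GuardrailEventHooks.post_call))  # False - not requested
```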
@@ -122,6 +122,7 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback
     async def async_post_call_success_hook(
         self,
+        data: dict,
         user_api_key_dict: UserAPIKeyAuth,
         response,
     ):
@@ -1,4 +1,12 @@
-from typing import Any
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+if TYPE_CHECKING:
+    from litellm import ModelResponse as _ModelResponse
+
+    LiteLLMModelResponse = _ModelResponse
+else:
+    LiteLLMModelResponse = Any
+
 
 import litellm
 
@@ -20,3 +28,21 @@ def convert_litellm_response_object_to_dict(response_obj: Any) -> dict:
 
     # If it's not a LiteLLM type, return the object as is
     return dict(response_obj)
+
+
+def convert_litellm_response_object_to_str(
+    response_obj: Union[Any, LiteLLMModelResponse]
+) -> Optional[str]:
+    """
+    Get the string of the response object from LiteLLM
+
+    """
+    if isinstance(response_obj, litellm.ModelResponse):
+        response_str = ""
+        for choice in response_obj.choices:
+            if isinstance(choice, litellm.Choices):
+                if choice.message.content and isinstance(choice.message.content, str):
+                    response_str += choice.message.content
+        return response_str
+
+    return None
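
`convert_litellm_response_object_to_str` concatenates the text content of every choice in a `ModelResponse`; the post-call Aporia hook uses it to get a single string to validate, and receives `None` for anything it cannot flatten. An illustrative call, assuming the usual litellm response types accept dict-shaped choices:

```python
import litellm
from litellm.litellm_core_utils.logging_utils import convert_litellm_response_object_to_str

# build a minimal ModelResponse and flatten it to a string
resp = litellm.ModelResponse(
    choices=[{"message": {"role": "assistant", "content": "Hello there!"}}]
)
print(convert_litellm_response_object_to_str(resp))  # "Hello there!"

# non-ModelResponse inputs fall through to None
print(convert_litellm_response_object_to_str({"raw": "dict"}))  # None
```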
@@ -118,17 +118,19 @@ def initialize_callbacks_on_proxy(
                 **init_params
             )
             imported_list.append(lakera_moderations_object)
-        elif isinstance(callback, str) and callback == "aporio_prompt_injection":
-            from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio
+        elif isinstance(callback, str) and callback == "aporia_prompt_injection":
+            from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
+                _ENTERPRISE_Aporia,
+            )
 
             if premium_user is not True:
                 raise Exception(
-                    "Trying to use Aporio AI Guardrail"
+                    "Trying to use Aporia AI Guardrail"
                     + CommonProxyErrors.not_premium_user.value
                 )
 
-            aporio_guardrail_object = _ENTERPRISE_Aporio()
-            imported_list.append(aporio_guardrail_object)
+            aporia_guardrail_object = _ENTERPRISE_Aporia()
+            imported_list.append(aporia_guardrail_object)
         elif isinstance(callback, str) and callback == "google_text_moderation":
             from enterprise.enterprise_hooks.google_text_moderation import (
                 _ENTERPRISE_GoogleTextModeration,
@@ -295,3 +297,21 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str,
         headers[f"x-litellm-key-remaining-tokens-{model_group}"] = remaining_tokens
 
     return headers
+
+
+def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]:
+    _metadata = request_data.get("metadata", None) or {}
+    if "applied_guardrails" in _metadata:
+        return {
+            "x-litellm-applied-guardrails": ",".join(_metadata["applied_guardrails"]),
+        }
+
+    return None
+
+
+def add_guardrail_to_applied_guardrails_header(request_data: Dict, guardrail_name: str):
+    _metadata = request_data.get("metadata", None) or {}
+    if "applied_guardrails" in _metadata:
+        _metadata["applied_guardrails"].append(guardrail_name)
+    else:
+        _metadata["applied_guardrails"] = [guardrail_name]
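
These two helpers are how the proxy reports which guardrails actually ran: each guardrail hook appends its name to `metadata["applied_guardrails"]`, and `get_custom_headers` later turns that list into an `x-litellm-applied-guardrails` response header. A minimal sketch of how they compose (the metadata contents are illustrative):

```python
from litellm.proxy.common_utils.callback_utils import (
    add_guardrail_to_applied_guardrails_header,
    get_applied_guardrails_header,
)

# illustrative request data; metadata must be non-empty for the helpers to mutate it in place
request_data = {"metadata": {"user_api_key_alias": "my-key"}}

# each guardrail that runs records itself ...
add_guardrail_to_applied_guardrails_header(request_data, guardrail_name="aporia-pre-guard")
add_guardrail_to_applied_guardrails_header(request_data, guardrail_name="aporia-post-guard")

# ... and the proxy exposes the list to the client as a response header
print(get_applied_guardrails_header(request_data))
# {'x-litellm-applied-guardrails': 'aporia-pre-guard,aporia-post-guard'}
```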
@@ -40,6 +40,7 @@ class MyCustomHandler(
     async def async_post_call_success_hook(
         self,
+        data: dict,
         user_api_key_dict: UserAPIKeyAuth,
         response,
     ):
@@ -9,3 +9,16 @@ litellm_settings:
   cache: true
   callbacks: ["otel"]
+
+guardrails:
+  - guardrail_name: "aporia-pre-guard"
+    litellm_params:
+      guardrail: aporia  # supported values: "aporia", "bedrock", "lakera"
+      mode: "post_call"
+      api_key: os.environ/APORIA_API_KEY_1
+      api_base: os.environ/APORIA_API_BASE_1
+  - guardrail_name: "aporia-post-guard"
+    litellm_params:
+      guardrail: aporia  # supported values: "aporia", "bedrock", "lakera"
+      mode: "post_call"
+      api_key: os.environ/APORIA_API_KEY_2
+      api_base: os.environ/APORIA_API_BASE_2
@@ -37,6 +37,9 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> bool:
 
     requested_callback_names = []
 
+    # v1 implementation of this
+    if isinstance(request_guardrails, dict):
+
     # get guardrail configs from `init_guardrails.py`
     # for all requested guardrails -> get their associated callbacks
     for _guardrail_name, should_run in request_guardrails.items():

litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py (new file, 212 lines)
@@ -0,0 +1,212 @@
|
||||||
|
# +-------------------------------------------------------------+
|
||||||
|
#
|
||||||
|
# Use AporiaAI for your LLM calls
|
||||||
|
#
|
||||||
|
# +-------------------------------------------------------------+
|
||||||
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import httpx
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
|
from litellm.litellm_core_utils.logging_utils import (
|
||||||
|
convert_litellm_response_object_to_str,
|
||||||
|
)
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
GUARDRAIL_NAME = "aporia"
|
||||||
|
|
||||||
|
|
||||||
|
class _ENTERPRISE_Aporia(CustomGuardrail):
|
||||||
|
def __init__(
|
||||||
|
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
|
self.async_handler = AsyncHTTPHandler(
|
||||||
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
|
)
|
||||||
|
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
|
||||||
|
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
|
||||||
|
self.event_hook: GuardrailEventHooks
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
#### CALL HOOKS - proxy only ####
|
||||||
|
def transform_messages(self, messages: List[dict]) -> List[dict]:
|
||||||
|
supported_openai_roles = ["system", "user", "assistant"]
|
||||||
|
default_role = "other" # for unsupported roles - e.g. tool
|
||||||
|
new_messages = []
|
||||||
|
for m in messages:
|
||||||
|
if m.get("role", "") in supported_openai_roles:
|
||||||
|
new_messages.append(m)
|
||||||
|
else:
|
||||||
|
new_messages.append(
|
||||||
|
{
|
||||||
|
"role": default_role,
|
||||||
|
**{key: value for key, value in m.items() if key != "role"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
async def prepare_aporia_request(
|
||||||
|
self, new_messages: List[dict], response_string: Optional[str] = None
|
||||||
|
) -> dict:
|
||||||
|
data: dict[str, Any] = {}
|
||||||
|
if new_messages is not None:
|
||||||
|
data["messages"] = new_messages
|
||||||
|
if response_string is not None:
|
||||||
|
data["response"] = response_string
|
||||||
|
|
||||||
|
# Set validation target
|
||||||
|
if new_messages and response_string:
|
||||||
|
data["validation_target"] = "both"
|
||||||
|
elif new_messages:
|
||||||
|
data["validation_target"] = "prompt"
|
||||||
|
elif response_string:
|
||||||
|
data["validation_target"] = "response"
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug("Aporia AI request: %s", data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def make_aporia_api_request(
|
||||||
|
self, new_messages: List[dict], response_string: Optional[str] = None
|
||||||
|
):
|
||||||
|
data = await self.prepare_aporia_request(
|
||||||
|
new_messages=new_messages, response_string=response_string
|
||||||
|
)
|
||||||
|
|
||||||
|
_json_data = json.dumps(data)
|
||||||
|
|
||||||
|
"""
|
||||||
|
export APORIO_API_KEY=<your key>
|
||||||
|
curl https://gr-prd-trial.aporia.com/some-id \
|
||||||
|
-X POST \
|
||||||
|
-H "X-APORIA-API-KEY: $APORIO_API_KEY" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "This is a test prompt"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
'
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = await self.async_handler.post(
|
||||||
|
url=self.aporia_api_base + "/validate",
|
||||||
|
data=_json_data,
|
||||||
|
headers={
|
||||||
|
"X-APORIA-API-KEY": self.aporia_api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
|
||||||
|
if response.status_code == 200:
|
||||||
|
# check if the response was flagged
|
||||||
|
_json_response = response.json()
|
||||||
|
action: str = _json_response.get(
|
||||||
|
"action"
|
||||||
|
) # possible values are modify, passthrough, block, rephrase
|
||||||
|
if action == "block":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "Violated guardrail policy",
|
||||||
|
"aporia_ai_response": _json_response,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_post_call_success_hook(
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
response,
|
||||||
|
):
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
add_guardrail_to_applied_guardrails_header,
|
||||||
|
)
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
"""
|
||||||
|
Use this for the post call moderation with Guardrails
|
||||||
|
"""
|
||||||
|
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
|
||||||
|
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||||
|
return
|
||||||
|
|
||||||
|
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
|
||||||
|
if response_str is not None:
|
||||||
|
await self.make_aporia_api_request(
|
||||||
|
response_string=response_str, new_messages=data.get("messages", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
add_guardrail_to_applied_guardrails_header(
|
||||||
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
|
)
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||||
|
):
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
add_guardrail_to_applied_guardrails_header,
|
||||||
|
)
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||||
|
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||||
|
return
|
||||||
|
|
||||||
|
# old implementation - backwards compatibility
|
||||||
|
if (
|
||||||
|
await should_proceed_based_on_metadata(
|
||||||
|
data=data,
|
||||||
|
guardrail_name=GUARDRAIL_NAME,
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
new_messages: Optional[List[dict]] = None
|
||||||
|
if "messages" in data and isinstance(data["messages"], list):
|
||||||
|
new_messages = self.transform_messages(messages=data["messages"])
|
||||||
|
|
||||||
|
if new_messages is not None:
|
||||||
|
await self.make_aporia_api_request(new_messages=new_messages)
|
||||||
|
add_guardrail_to_applied_guardrails_header(
|
||||||
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
"Aporia AI: not running guardrail. No messages in data"
|
||||||
|
)
|
||||||
|
pass
|
|
@@ -1,12 +1,20 @@
 import traceback
-from typing import Dict, List
+from typing import Dict, List, Literal
 
 from pydantic import BaseModel, RootModel
 
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
-from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
+
+# v2 implementation
+from litellm.types.guardrails import (
+    Guardrail,
+    GuardrailItem,
+    GuardrailItemSpec,
+    LitellmParams,
+    guardrailConfig,
+)
 
 all_guardrails: List[GuardrailItem] = []
 
@@ -66,3 +74,68 @@ def initialize_guardrails(
             "error initializing guardrails {}".format(str(e))
         )
         raise e
+
+
+"""
+Map guardrail_name: <pre_call>, <post_call>, during_call
+
+"""
+
+
+def init_guardrails_v2(all_guardrails: dict):
+    # Convert the loaded data to the TypedDict structure
+    guardrail_list = []
+
+    # Parse each guardrail and replace environment variables
+    for guardrail in all_guardrails:
+
+        # Init litellm params for guardrail
+        litellm_params_data = guardrail["litellm_params"]
+        verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
+        litellm_params = LitellmParams(
+            guardrail=litellm_params_data["guardrail"],
+            mode=litellm_params_data["mode"],
+            api_key=litellm_params_data["api_key"],
+            api_base=litellm_params_data["api_base"],
+        )
+
+        if litellm_params["api_key"]:
+            if litellm_params["api_key"].startswith("os.environ/"):
+                litellm_params["api_key"] = litellm.get_secret(
+                    litellm_params["api_key"]
+                )
+
+        if litellm_params["api_base"]:
+            if litellm_params["api_base"].startswith("os.environ/"):
+                litellm_params["api_base"] = litellm.get_secret(
+                    litellm_params["api_base"]
+                )
+
+        # Init guardrail CustomLoggerClass
+        if litellm_params["guardrail"] == "aporia":
+            from guardrail_hooks.aporia_ai import _ENTERPRISE_Aporia
+
+            _aporia_callback = _ENTERPRISE_Aporia(
+                api_base=litellm_params["api_base"],
+                api_key=litellm_params["api_key"],
+                guardrail_name=guardrail["guardrail_name"],
+                event_hook=litellm_params["mode"],
+            )
+            litellm.callbacks.append(_aporia_callback)  # type: ignore
+        elif litellm_params["guardrail"] == "lakera":
+            from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
+                _ENTERPRISE_lakeraAI_Moderation,
+            )
+
+            _lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
+            litellm.callbacks.append(_lakera_callback)  # type: ignore
+
+        parsed_guardrail = Guardrail(
+            guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params
+        )
+
+        guardrail_list.append(parsed_guardrail)
+        guardrail_name = guardrail["guardrail_name"]
+
+    # pretty print guardrail_list in green
+    print(f"\nGuardrail List:{guardrail_list}\n")  # noqa
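
`init_guardrails_v2` consumes the raw `guardrails:` list parsed from the proxy config, resolves any `os.environ/...` values through `litellm.get_secret`, and registers one callback object per guardrail. A sketch of the equivalent direct call (the values are illustrative and assume the referenced APORIA environment variables are set):

```python
from litellm.proxy.guardrails.init_guardrails import init_guardrails_v2

# equivalent to what ProxyConfig does when it finds a `guardrails:` block in config.yaml
init_guardrails_v2(
    all_guardrails=[
        {
            "guardrail_name": "aporia-pre-guard",
            "litellm_params": {
                "guardrail": "aporia",       # selects the Aporia callback class
                "mode": "during_call",       # becomes the guardrail's event_hook
                "api_key": "os.environ/APORIA_API_KEY_1",   # resolved via litellm.get_secret
                "api_base": "os.environ/APORIA_API_BASE_1",
            },
        }
    ]
)
```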
@@ -1,11 +1,16 @@
-from litellm.integrations.custom_logger import CustomLogger
-from litellm.caching import DualCache
-from litellm.proxy._types import UserAPIKeyAuth
-import litellm, traceback, sys, uuid
-from fastapi import HTTPException
-from litellm._logging import verbose_proxy_logger
+import sys
+import traceback
+import uuid
 from typing import Optional
 
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
 
 class _PROXY_AzureContentSafety(
     CustomLogger
@@ -15,12 +20,12 @@ class _PROXY_AzureContentSafety(
     def __init__(self, endpoint, api_key, thresholds=None):
         try:
             from azure.ai.contentsafety.aio import ContentSafetyClient
-            from azure.core.credentials import AzureKeyCredential
             from azure.ai.contentsafety.models import (
-                TextCategory,
                 AnalyzeTextOptions,
                 AnalyzeTextOutputType,
+                TextCategory,
             )
+            from azure.core.credentials import AzureKeyCredential
             from azure.core.exceptions import HttpResponseError
         except Exception as e:
             raise Exception(
@@ -132,6 +137,7 @@ class _PROXY_AzureContentSafety(
     async def async_post_call_success_hook(
         self,
+        data: dict,
         user_api_key_dict: UserAPIKeyAuth,
         response,
     ):
@@ -254,7 +254,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         return None
 
     async def async_post_call_success_hook(
-        self, user_api_key_dict: UserAPIKeyAuth, response
+        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
     ):
         try:
             if isinstance(response, ModelResponse):
@@ -287,7 +287,9 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
 
                 return response
             return await super().async_post_call_success_hook(
-                user_api_key_dict, response
+                data=data,
+                user_api_key_dict=user_api_key_dict,
+                response=response,
             )
         except Exception as e:
             verbose_proxy_logger.exception(
@@ -322,6 +322,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
     async def async_post_call_success_hook(
         self,
+        data: dict,
         user_api_key_dict: UserAPIKeyAuth,
         response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
     ):
@@ -316,9 +316,20 @@ async def add_litellm_data_to_request(
     for k, v in callback_settings_obj.callback_vars.items():
         data[k] = v
 
+    # Guardrails
+    move_guardrails_to_metadata(
+        data=data, _metadata_variable_name=_metadata_variable_name
+    )
+
     return data
 
 
+def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str):
+    if "guardrails" in data:
+        data[_metadata_variable_name]["guardrails"] = data["guardrails"]
+        del data["guardrails"]
+
+
 def add_provider_specific_headers_to_request(
     data: dict,
     headers: dict,
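
`move_guardrails_to_metadata` is what lets clients pass a top-level `"guardrails"` list in the request body: it relocates that list into the request metadata, where `should_run_guardrail` later looks for it. A minimal before/after sketch, assuming the metadata key is `"metadata"` and the helper is imported from the pre-call utils module:

```python
from litellm.proxy.litellm_pre_call_utils import move_guardrails_to_metadata

data = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hi my email is ishaan@berri.ai"}],
    "guardrails": ["aporia-pre-guard", "aporia-post-guard"],
    "metadata": {},
}

move_guardrails_to_metadata(data=data, _metadata_variable_name="metadata")

# the top-level key is gone; the list now lives under metadata:
# data == {
#     "model": "gpt-3.5-turbo",
#     "messages": [...],
#     "metadata": {"guardrails": ["aporia-pre-guard", "aporia-post-guard"]},
# }
```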
|
@@ -1,50 +1,20 @@
 model_list:
-  - model_name: gpt-4
+  - model_name: fake-openai-endpoint
     litellm_params:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
-    model_info:
-      access_groups: ["beta-models"]
-  - model_name: fireworks-llama-v3-70b-instruct
-    litellm_params:
-      model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
-      api_key: "os.environ/FIREWORKS"
-    model_info:
-      access_groups: ["beta-models"]
-  - model_name: "*"
-    litellm_params:
-      model: "*"
-  - model_name: "*"
-    litellm_params:
-      model: openai/*
-      api_key: os.environ/OPENAI_API_KEY
-  - model_name: mistral-small-latest
-    litellm_params:
-      model: mistral/mistral-small-latest
-      api_key: "os.environ/MISTRAL_API_KEY"
-  - model_name: bedrock-anthropic
-    litellm_params:
-      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
-  - model_name: gemini-1.5-pro-001
-    litellm_params:
-      model: vertex_ai_beta/gemini-1.5-pro-001
-      vertex_project: "adroit-crow-413218"
-      vertex_location: "us-central1"
-      vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json"
-      # Add path to service account.json

-default_vertex_config:
-  vertex_project: "adroit-crow-413218"
-  vertex_location: "us-central1"
-  vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
-
-general_settings:
-  master_key: sk-1234
-  alerting: ["slack"]
-
-litellm_settings:
-  fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
-  success_callback: ["langfuse", "prometheus"]
-  langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]
+guardrails:
+  - guardrail_name: "aporia-pre-guard"
+    litellm_params:
+      guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
+      mode: "post_call"
+      api_key: os.environ/APORIA_API_KEY_1
+      api_base: os.environ/APORIA_API_BASE_1
+  - guardrail_name: "aporia-post-guard"
+    litellm_params:
+      guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
+      mode: "post_call"
+      api_key: os.environ/APORIA_API_KEY_2
+      api_base: os.environ/APORIA_API_BASE_2
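The guardrails registered in this config can be selected per request by name; the "guardrails" request key and the x-litellm-applied-guardrails response header are the ones exercised in tests/otel_tests/test_guardrails.py later in this diff. A minimal client-side sketch against a local proxy (URL, key, and model mirror the test defaults; the requests dependency is an assumption):

import requests

resp = requests.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "fake-openai-endpoint",
        "messages": [{"role": "user", "content": "Hello what's the weather"}],
        "guardrails": ["aporia-pre-guard", "aporia-post-guard"],
    },
)
# The proxy reports which guardrails ran via a response header.
print(resp.headers.get("x-litellm-applied-guardrails"))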
@@ -149,6 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
     show_missing_vars_in_env,
 )
 from litellm.proxy.common_utils.callback_utils import (
+    get_applied_guardrails_header,
     get_remaining_tokens_and_requests_from_request_data,
     initialize_callbacks_on_proxy,
 )
@@ -168,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
 )
 from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
 from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
-from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
+from litellm.proxy.guardrails.init_guardrails import (
+    init_guardrails_v2,
+    initialize_guardrails,
+)
 from litellm.proxy.health_check import perform_health_check
 from litellm.proxy.health_endpoints._health_endpoints import router as health_router
 from litellm.proxy.hooks.prompt_injection_detection import (
@@ -539,6 +543,10 @@ def get_custom_headers(
     )
     headers.update(remaining_tokens_header)

+    applied_guardrails = get_applied_guardrails_header(request_data)
+    if applied_guardrails:
+        headers.update(applied_guardrails)
+
     try:
         return {
             key: value for key, value in headers.items() if value not in exclude_values
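get_applied_guardrails_header itself is not shown in this diff; judging from the header asserted in the new tests (x-litellm-applied-guardrails with comma-joined guardrail names), a stand-in with the same observable contract could look like the sketch below. This is an illustrative assumption, not the repository's implementation.

from typing import Optional


def get_applied_guardrails_header(request_data: dict) -> Optional[dict]:
    # Hypothetical stand-in: expose the guardrails recorded in the request
    # metadata as a single comma-joined response header.
    _metadata = request_data.get("metadata") or {}
    guardrail_names = _metadata.get("guardrails", [])
    if not guardrail_names:
        return None
    return {"x-litellm-applied-guardrails": ",".join(guardrail_names)}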
@@ -1937,6 +1945,11 @@ class ProxyConfig:
                     async_only_mode=True # only init async clients
                 ),
             ) # type:ignore
+
+            # Guardrail settings
+            guardrails_v2 = config.get("guardrails", None)
+            if guardrails_v2:
+                init_guardrails_v2(all_guardrails=guardrails_v2)
             return router, router.get_model_list(), general_settings

     def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
@@ -3139,7 +3152,7 @@ async def chat_completion(

         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
-            user_api_key_dict=user_api_key_dict, response=response
+            data=data, user_api_key_dict=user_api_key_dict, response=response
         )

         hidden_params = (
@@ -3353,6 +3366,11 @@ async def completion(
                 media_type="text/event-stream",
                 headers=custom_headers,
             )
+        ### CALL HOOKS ### - modify outgoing data
+        response = await proxy_logging_obj.post_call_success_hook(
+            data=data, user_api_key_dict=user_api_key_dict, response=response
+        )
+
         fastapi_response.headers.update(
             get_custom_headers(
                 user_api_key_dict=user_api_key_dict,
@@ -432,12 +432,11 @@ class ProxyLogging:
         """
         Runs the CustomLogger's async_moderation_hook()
         """
-        new_data = safe_deep_copy(data)
         for callback in litellm.callbacks:
             try:
                 if isinstance(callback, CustomLogger):
                     await callback.async_moderation_hook(
-                        data=new_data,
+                        data=data,
                         user_api_key_dict=user_api_key_dict,
                         call_type=call_type,
                     )
@@ -717,6 +716,7 @@ class ProxyLogging:

     async def post_call_success_hook(
         self,
+        data: dict,
         response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
         user_api_key_dict: UserAPIKeyAuth,
     ):
@@ -738,7 +738,9 @@ class ProxyLogging:
                     _callback = callback # type: ignore
                 if _callback is not None and isinstance(_callback, CustomLogger):
                     await _callback.async_post_call_success_hook(
-                        user_api_key_dict=user_api_key_dict, response=response
+                        user_api_key_dict=user_api_key_dict,
+                        data=data,
+                        response=response,
                     )
             except Exception as e:
                 raise e
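With data now threaded through post_call_success_hook, output-side hooks can inspect the original request alongside the model response. A minimal sketch of such a hook, assuming the usual CustomLogger import path; the class name and the check it performs are illustrative, not part of this PR:

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class ExampleOutputCheck(CustomLogger):  # illustrative guardrail-style hook
    async def async_post_call_success_hook(
        self,
        data: dict,  # original request body, newly available to the hook
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        request_messages = data.get("messages", [])
        # A real hook would inspect request_messages and the response here,
        # raising an HTTPException to block or returning a modified response.
        return response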
@@ -1,8 +1,13 @@
 # What is this?
 ## Unit test for azure content safety
-import sys, os, asyncio, time, random
-from datetime import datetime
+import asyncio
+import os
+import random
+import sys
+import time
 import traceback
+from datetime import datetime
+
 from dotenv import load_dotenv
 from fastapi import HTTPException
@@ -13,11 +18,12 @@ sys.path.insert(
     0, os.path.abspath("../..")
 ) # Adds the parent directory to the system path
 import pytest
+
 import litellm
 from litellm import Router, mock_completion
-from litellm.proxy.utils import ProxyLogging
-from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.utils import ProxyLogging


 @pytest.mark.asyncio
@@ -177,7 +183,13 @@ async def test_strict_output_filtering_01():

     with pytest.raises(HTTPException) as exc_info:
         await azure_content_safety.async_post_call_success_hook(
-            user_api_key_dict=UserAPIKeyAuth(), response=response
+            user_api_key_dict=UserAPIKeyAuth(),
+            data={
+                "messages": [
+                    {"role": "system", "content": "You are an helpfull assistant"}
+                ]
+            },
+            response=response,
         )

     assert exc_info.value.detail["source"] == "output"
@@ -216,7 +228,11 @@ async def test_strict_output_filtering_02():
     )

     await azure_content_safety.async_post_call_success_hook(
-        user_api_key_dict=UserAPIKeyAuth(), response=response
+        user_api_key_dict=UserAPIKeyAuth(),
+        data={
+            "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+        },
+        response=response,
     )

@@ -251,7 +267,11 @@ async def test_loose_output_filtering_01():
     )

     await azure_content_safety.async_post_call_success_hook(
-        user_api_key_dict=UserAPIKeyAuth(), response=response
+        user_api_key_dict=UserAPIKeyAuth(),
+        data={
+            "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+        },
+        response=response,
     )

@@ -286,5 +306,9 @@ async def test_loose_output_filtering_02():
     )

     await azure_content_safety.async_post_call_success_hook(
-        user_api_key_dict=UserAPIKeyAuth(), response=response
+        user_api_key_dict=UserAPIKeyAuth(),
+        data={
+            "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+        },
+        response=response,
     )
@@ -88,7 +88,11 @@ async def test_output_parsing():
         mock_response="Hello <PERSON>! How can I assist you today?",
     )
     new_response = await pii_masking.async_post_call_success_hook(
-        user_api_key_dict=UserAPIKeyAuth(), response=response
+        user_api_key_dict=UserAPIKeyAuth(),
+        data={
+            "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+        },
+        response=response,
     )

     assert (
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, TypedDict

 from pydantic import BaseModel, ConfigDict
 from typing_extensions import Required, TypedDict
@@ -63,3 +63,26 @@ class GuardrailItem(BaseModel):
             enabled_roles=enabled_roles,
             callback_args=callback_args,
         )
+
+
+# Define the TypedDicts
+class LitellmParams(TypedDict):
+    guardrail: str
+    mode: str
+    api_key: str
+    api_base: Optional[str]
+
+
+class Guardrail(TypedDict):
+    guardrail_name: str
+    litellm_params: LitellmParams
+
+
+class guardrailConfig(TypedDict):
+    guardrails: List[Guardrail]
+
+
+class GuardrailEventHooks(str, Enum):
+    pre_call = "pre_call"
+    post_call = "post_call"
+    during_call = "during_call"
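These TypedDicts describe the shape of the parsed guardrails: block from the config earlier in this diff, i.e. roughly the value that init_guardrails_v2 receives. A small sketch using plain dict literals, with values copied from that config (the os.environ/ strings are shown literally rather than resolved):

guardrails_v2 = [
    {
        "guardrail_name": "aporia-pre-guard",
        "litellm_params": {
            "guardrail": "aporia",
            "mode": "post_call",
            "api_key": "os.environ/APORIA_API_KEY_1",
            "api_base": "os.environ/APORIA_API_BASE_1",
        },
    },
    {
        "guardrail_name": "aporia-post-guard",
        "litellm_params": {
            "guardrail": "aporia",
            "mode": "post_call",
            "api_key": "os.environ/APORIA_API_KEY_2",
            "api_base": "os.environ/APORIA_API_BASE_2",
        },
    },
]
# Each entry matches Guardrail, and each "litellm_params" value matches LitellmParams.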
tests/otel_tests/test_guardrails.py (new file, 118 lines)

@@ -0,0 +1,118 @@
+import pytest
+import asyncio
+import aiohttp, openai
+from openai import OpenAI, AsyncOpenAI
+from typing import Optional, List, Union
+import uuid
+
+
+async def chat_completion(
+    session,
+    key,
+    messages,
+    model: Union[str, List] = "gpt-4",
+    guardrails: Optional[List] = None,
+):
+    url = "http://0.0.0.0:4000/chat/completions"
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
+
+    data = {
+        "model": model,
+        "messages": messages,
+        "guardrails": [
+            "aporia-post-guard",
+            "aporia-pre-guard",
+        ],  # default guardrails for all tests
+    }
+
+    if guardrails is not None:
+        data["guardrails"] = guardrails
+
+    print("data=", data)
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            return response_text
+
+        # response headers
+        response_headers = response.headers
+        print("response headers=", response_headers)
+
+        return await response.json(), response_headers
+
+
+@pytest.mark.asyncio
+async def test_llm_guard_triggered_safe_request():
+    """
+    - Tests a request where no content mod is triggered
+    - Assert that the guardrails applied are returned in the response headers
+    """
+    async with aiohttp.ClientSession() as session:
+        response, headers = await chat_completion(
+            session,
+            "sk-1234",
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": f"Hello what's the weather"}],
+        )
+        await asyncio.sleep(3)
+
+        print("response=", response, "response headers", headers)
+
+        assert "x-litellm-applied-guardrails" in headers
+
+        assert (
+            headers["x-litellm-applied-guardrails"]
+            == "aporia-pre-guard,aporia-post-guard"
+        )
+
+
+@pytest.mark.asyncio
+async def test_llm_guard_triggered():
+    """
+    - Tests a request where no content mod is triggered
+    - Assert that the guardrails applied are returned in the response headers
+    """
+    async with aiohttp.ClientSession() as session:
+        try:
+            response, headers = await chat_completion(
+                session,
+                "sk-1234",
+                model="fake-openai-endpoint",
+                messages=[
+                    {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
+                ],
+            )
+            pytest.fail("Should have thrown an exception")
+        except Exception as e:
+            print(e)
+            assert "Aporia detected and blocked PII" in str(e)
+
+
+@pytest.mark.asyncio
+async def test_no_llm_guard_triggered():
+    """
+    - Tests a request where no content mod is triggered
+    - Assert that the guardrails applied are returned in the response headers
+    """
+    async with aiohttp.ClientSession() as session:
+        response, headers = await chat_completion(
+            session,
+            "sk-1234",
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": f"Hello what's the weather"}],
+            guardrails=[],
+        )
+        await asyncio.sleep(3)
+
+        print("response=", response, "response headers", headers)
+
+        assert "x-litellm-applied-guardrails" not in headers
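Note: these tests assume a proxy already running on http://0.0.0.0:4000 with the guardrails config shown earlier in this diff, the sk-1234 key they send as the bearer token, and the APORIA_API_KEY_1/APORIA_API_BASE_1 and APORIA_API_KEY_2/APORIA_API_BASE_2 environment variables set; with that in place they can presumably be run directly with pytest against tests/otel_tests/test_guardrails.py.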