diff --git a/.circleci/config.yml b/.circleci/config.yml
index b43a8aa64c..854bb40f71 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -317,6 +317,10 @@ jobs:
-e OPENAI_API_KEY=$OPENAI_API_KEY \
-e LITELLM_LICENSE=$LITELLM_LICENSE \
-e OTEL_EXPORTER="in_memory" \
+ -e APORIA_API_BASE_2=$APORIA_API_BASE_2 \
+ -e APORIA_API_KEY_2=$APORIA_API_KEY_2 \
+ -e APORIA_API_BASE_1=$APORIA_API_BASE_1 \
+ -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
--name my-app \
-v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
my-app:latest \
diff --git a/docs/my-website/docs/proxy/call_hooks.md b/docs/my-website/docs/proxy/call_hooks.md
index ce34e5ad6b..25a46609d3 100644
--- a/docs/my-website/docs/proxy/call_hooks.md
+++ b/docs/my-website/docs/proxy/call_hooks.md
@@ -47,6 +47,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
async def async_post_call_success_hook(
self,
+ data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md
index 33a899222b..94813e354b 100644
--- a/docs/my-website/docs/proxy/enterprise.md
+++ b/docs/my-website/docs/proxy/enterprise.md
@@ -36,7 +36,7 @@ Features:
- **Guardrails, PII Masking, Content Moderation**
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
- - ✅ [Prompt Injection Detection (with Aporio API)](#prompt-injection-detection---aporio-ai)
+ - ✅ [Prompt Injection Detection (with Aporia API)](#prompt-injection-detection---aporia-ai)
- ✅ [Switch LakeraAI on / off per request](guardrails#control-guardrails-onoff-per-request)
- ✅ Reject calls from Blocked User list
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
@@ -1035,9 +1035,9 @@ curl --location 'http://localhost:4000/chat/completions' \
Need to control LakeraAI per Request ? Doc here 👉: [Switch LakerAI on / off per request](prompt_injection.md#✨-enterprise-switch-lakeraai-on--off-per-api-call)
:::
-## Prompt Injection Detection - Aporio AI
+## Prompt Injection Detection - Aporia AI
-Use this if you want to reject /chat/completion calls that have prompt injection attacks with [AporioAI](https://www.aporia.com/)
+Use this if you want to reject /chat/completion calls that have prompt injection attacks with [AporiaAI](https://www.aporia.com/)
#### Usage
@@ -1048,11 +1048,11 @@ APORIO_API_KEY="eyJh****"
APORIO_API_BASE="https://gr..."
```
-Step 2. Add `aporio_prompt_injection` to your callbacks
+Step 2. Add `aporia_prompt_injection` to your callbacks
```yaml
litellm_settings:
- callbacks: ["aporio_prompt_injection"]
+ callbacks: ["aporia_prompt_injection"]
```
That's it, start your proxy
@@ -1081,7 +1081,7 @@ curl --location 'http://localhost:4000/chat/completions' \
"error": {
"message": {
"error": "Violated guardrail policy",
- "aporio_ai_response": {
+ "aporia_ai_response": {
"action": "block",
"revised_prompt": null,
"revised_response": "Profanity detected: Message blocked because it includes profanity. Please rephrase.",
@@ -1097,7 +1097,7 @@ curl --location 'http://localhost:4000/chat/completions' \
:::info
-Need to control AporioAI per Request ? Doc here 👉: [Create a guardrail](./guardrails.md)
+Need to control AporiaAI per Request ? Doc here 👉: [Create a guardrail](./guardrails.md)
:::
diff --git a/docs/my-website/docs/proxy/guardrails.md b/docs/my-website/docs/proxy/guardrails.md
index 698e97f9a8..451ca8ab50 100644
--- a/docs/my-website/docs/proxy/guardrails.md
+++ b/docs/my-website/docs/proxy/guardrails.md
@@ -1,18 +1,10 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
-# 🛡️ Guardrails
+# 🛡️ [Beta] Guardrails
Setup Prompt Injection Detection, Secret Detection on LiteLLM Proxy
-:::info
-
-✨ Enterprise Only Feature
-
-Schedule a meeting with us to get an Enterprise License 👉 Talk to founders [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
-
-:::
-
## Quick Start
### 1. Setup guardrails on litellm proxy config.yaml
diff --git a/docs/my-website/docs/tutorials/ab_test_llms.md b/docs/my-website/docs/tutorials/ab_test_llms.md
deleted file mode 100644
index b08e913529..0000000000
--- a/docs/my-website/docs/tutorials/ab_test_llms.md
+++ /dev/null
@@ -1,98 +0,0 @@
-import Image from '@theme/IdealImage';
-
-# Split traffic betwen GPT-4 and Llama2 in Production!
-In this tutorial, we'll walk through A/B testing between GPT-4 and Llama2 in production. We'll assume you've deployed Llama2 on Huggingface Inference Endpoints (but any of TogetherAI, Baseten, Ollama, Petals, Openrouter should work as well).
-
-
-# Relevant Resources:
-
-* 🚀 [Your production dashboard!](https://admin.litellm.ai/)
-
-
-* [Deploying models on Huggingface](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint)
-* [All supported providers on LiteLLM](https://docs.litellm.ai/docs/providers)
-
-# Code Walkthrough
-
-In production, we don't know if Llama2 is going to provide:
-* good results
-* quickly
-
-### 💡 Route 20% traffic to Llama2
-If Llama2 returns poor answers / is extremely slow, we want to roll-back this change, and use GPT-4 instead.
-
-Instead of routing 100% of our traffic to Llama2, let's **start by routing 20% traffic** to it and see how it does.
-
-```python
-## route 20% of responses to Llama2
-split_per_model = {
- "gpt-4": 0.8,
- "huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2
-}
-```
-
-## 👨💻 Complete Code
-
-### a) For Local
-If we're testing this in a script - this is what our complete code looks like.
-```python
-from litellm import completion_with_split_tests
-import os
-
-## set ENV variables
-os.environ["OPENAI_API_KEY"] = "openai key"
-os.environ["HUGGINGFACE_API_KEY"] = "huggingface key"
-
-## route 20% of responses to Llama2
-split_per_model = {
- "gpt-4": 0.8,
- "huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2
-}
-
-messages = [{ "content": "Hello, how are you?","role": "user"}]
-
-completion_with_split_tests(
- models=split_per_model,
- messages=messages,
-)
-```
-
-### b) For Production
-
-If we're in production, we don't want to keep going to code to change model/test details (prompt, split%, etc.) for our completion function and redeploying changes.
-
-LiteLLM exposes a client dashboard to do this in a UI - and instantly updates our completion function in prod.
-
-#### Relevant Code
-
-```python
-completion_with_split_tests(..., use_client=True, id="my-unique-id")
-```
-
-#### Complete Code
-
-```python
-from litellm import completion_with_split_tests
-import os
-
-## set ENV variables
-os.environ["OPENAI_API_KEY"] = "openai key"
-os.environ["HUGGINGFACE_API_KEY"] = "huggingface key"
-
-## route 20% of responses to Llama2
-split_per_model = {
- "gpt-4": 0.8,
- "huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2
-}
-
-messages = [{ "content": "Hello, how are you?","role": "user"}]
-
-completion_with_split_tests(
- models=split_per_model,
- messages=messages,
- use_client=True,
- id="my-unique-id" # Auto-create this @ https://admin.litellm.ai/
-)
-```
-
-
diff --git a/docs/my-website/docs/tutorials/litellm_proxy_aporia.md b/docs/my-website/docs/tutorials/litellm_proxy_aporia.md
new file mode 100644
index 0000000000..480c411c0f
--- /dev/null
+++ b/docs/my-website/docs/tutorials/litellm_proxy_aporia.md
@@ -0,0 +1,196 @@
+import Image from '@theme/IdealImage';
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# Use LiteLLM AI Gateway with Aporia Guardrails
+
+In this tutorial we will use LiteLLM Proxy with Aporia to detect PII in requests and profanity in responses
+
+## 1. Setup guardrails on Aporia
+
+### Create Aporia Projects
+
+Create two projects on [Aporia](https://guardrails.aporia.com/)
+
+1. Pre LLM API Call - Set all the policies you want to run before the LLM API call
+2. Post LLM API Call - Set all the policies you want to run after the LLM API call
+
+
+
+
+
+### Pre-Call: Detect PII
+
+Add the `PII - Prompt` policy to your Pre LLM API Call project
+
+
+
+### Post-Call: Detect Profanity in Responses
+
+Add the `Toxicity - Response` policy to your Post LLM API Call project
+
+
+
+
+## 2. Define Guardrails on your LiteLLM config.yaml
+
+- Define your guardrails under the `guardrails` section, and set `mode` for each guardrail to control when it runs
+```yaml
+model_list:
+ - model_name: gpt-3.5-turbo
+ litellm_params:
+ model: openai/gpt-3.5-turbo
+ api_key: os.environ/OPENAI_API_KEY
+
+guardrails:
+ - guardrail_name: "aporia-pre-guard"
+ litellm_params:
+ guardrail: aporia # supported values: "aporia", "lakera"
+ mode: "during_call"
+ api_key: os.environ/APORIA_API_KEY_1
+ api_base: os.environ/APORIA_API_BASE_1
+ - guardrail_name: "aporia-post-guard"
+ litellm_params:
+ guardrail: aporia # supported values: "aporia", "lakera"
+ mode: "post_call"
+ api_key: os.environ/APORIA_API_KEY_2
+ api_base: os.environ/APORIA_API_BASE_2
+```
+
+### Supported values for `mode`
+
+- `pre_call` Run **before** LLM call, on **input**
+- `post_call` Run **after** LLM call, on **input & output**
+- `during_call` Run **during** LLM call, on **input**
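+
+For example, a guardrail that should only screen the input before the LLM call could use `mode: "pre_call"` (a minimal sketch; the guardrail name and env vars below are placeholders, not values from this tutorial):
+
+```yaml
+guardrails:
+  - guardrail_name: "aporia-pre-call-guard"
+    litellm_params:
+      guardrail: aporia
+      mode: "pre_call"   # run before the LLM call, on input only
+      api_key: os.environ/APORIA_API_KEY_1
+      api_base: os.environ/APORIA_API_BASE_1
+```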
+
+## 3. Start LiteLLM Gateway
+
+
+```shell
+litellm --config config.yaml --detailed_debug
+```
+
+## 4. Test request
+
+
+
+
+Expect this to fail, since `ishaan@berri.ai` in the request is PII
+
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
+ -d '{
+ "model": "gpt-3.5-turbo",
+ "messages": [
+ {"role": "user", "content": "hi my email is ishaan@berri.ai"}
+ ],
+ "guardrails": ["aporia-pre-guard", "aporia-post-guard"]
+ }'
+```
+
+Expected response on failure
+
+```shell
+{
+ "error": {
+ "message": {
+ "error": "Violated guardrail policy",
+ "aporia_ai_response": {
+ "action": "block",
+ "revised_prompt": null,
+ "revised_response": "Aporia detected and blocked PII",
+ "explain_log": null
+ }
+ },
+ "type": "None",
+ "param": "None",
+ "code": "400"
+ }
+}
+
+```
+
+
+
+
+
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
+ -d '{
+ "model": "gpt-3.5-turbo",
+ "messages": [
+ {"role": "user", "content": "hi what is the weather"}
+ ],
+ "guardrails": ["aporia-pre-guard", "aporia-post-guard"]
+ }'
+```
+
+
+
+
+
+
+## Advanced
+### Control Guardrails per Project (API Key)
+
+Use this to control which guardrails run per project (API key). In this tutorial we only want the following guardrails to run for one project:
+- `pre_call_guardrails`: ["aporia-pre-guard"]
+- `post_call_guardrails`: ["aporia-post-guard"]
+
+**Step 1** Create Key with guardrail settings
+
+
+
+
+```shell
+curl -X POST 'http://0.0.0.0:4000/key/generate' \
+ -H 'Authorization: Bearer sk-1234' \
+ -H 'Content-Type: application/json' \
+    -d '{
+            "pre_call_guardrails": ["aporia-pre-guard"],
+            "post_call_guardrails": ["aporia-post-guard"]
+        }'
+```
+
+
+
+
+```shell
+curl --location 'http://0.0.0.0:4000/key/update' \
+ --header 'Authorization: Bearer sk-1234' \
+ --header 'Content-Type: application/json' \
+ --data '{
+ "key": "sk-jNm1Zar7XfNdZXp49Z1kSQ",
+ "pre_call_guardrails": ["aporia"],
+ "post_call_guardrails": ["aporia"]
+ }
+}'
+```
+
+
+
+
+**Step 2** Test it with new key
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+ --header 'Authorization: Bearer sk-jNm1Zar7XfNdZXp49Z1kSQ' \
+ --header 'Content-Type: application/json' \
+ --data '{
+ "model": "gpt-3.5-turbo",
+ "messages": [
+ {
+ "role": "user",
+ "content": "my email is ishaan@berri.ai"
+ }
+ ]
+}'
+```
+
+
+
diff --git a/docs/my-website/img/aporia_post.png b/docs/my-website/img/aporia_post.png
new file mode 100644
index 0000000000..5e4d4a287b
Binary files /dev/null and b/docs/my-website/img/aporia_post.png differ
diff --git a/docs/my-website/img/aporia_pre.png b/docs/my-website/img/aporia_pre.png
new file mode 100644
index 0000000000..8df1cfdda9
Binary files /dev/null and b/docs/my-website/img/aporia_pre.png differ
diff --git a/docs/my-website/img/aporia_projs.png b/docs/my-website/img/aporia_projs.png
new file mode 100644
index 0000000000..c518fdf0bd
Binary files /dev/null and b/docs/my-website/img/aporia_projs.png differ
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 1e550e6e77..6501ebd757 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -250,6 +250,7 @@ const sidebars = {
type: "category",
label: "Tutorials",
items: [
+ 'tutorials/litellm_proxy_aporia',
'tutorials/azure_openai',
'tutorials/instructor',
"tutorials/gradio_integration",
diff --git a/enterprise/enterprise_hooks/aporia_ai.py b/enterprise/enterprise_hooks/aporia_ai.py
new file mode 100644
index 0000000000..af909a8b51
--- /dev/null
+++ b/enterprise/enterprise_hooks/aporia_ai.py
@@ -0,0 +1,208 @@
+# +-------------------------------------------------------------+
+#
+# Use AporiaAI for your LLM calls
+#
+# +-------------------------------------------------------------+
+# Thank you users! We ❤️ you! - Krrish & Ishaan
+
+import sys
+import os
+
+sys.path.insert(
+ 0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+from typing import Optional, Literal, Union, Any
+import litellm, traceback, sys, uuid
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
+from litellm.litellm_core_utils.logging_utils import (
+ convert_litellm_response_object_to_str,
+)
+from typing import List
+from datetime import datetime
+import aiohttp, asyncio
+from litellm._logging import verbose_proxy_logger
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+import httpx
+import json
+from litellm.types.guardrails import GuardrailEventHooks
+
+litellm.set_verbose = True
+
+GUARDRAIL_NAME = "aporia"
+
+
+class _ENTERPRISE_Aporia(CustomGuardrail):
+ def __init__(
+ self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
+ ):
+ self.async_handler = AsyncHTTPHandler(
+ timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+ )
+ self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
+ self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
+ self.event_hook: GuardrailEventHooks
+
+ super().__init__(**kwargs)
+
+ #### CALL HOOKS - proxy only ####
+ def transform_messages(self, messages: List[dict]) -> List[dict]:
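+        """
+        Replace any role outside system/user/assistant (e.g. "tool") with a
+        generic "other" role, keeping the rest of each message unchanged.
+        """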
+ supported_openai_roles = ["system", "user", "assistant"]
+ default_role = "other" # for unsupported roles - e.g. tool
+ new_messages = []
+ for m in messages:
+ if m.get("role", "") in supported_openai_roles:
+ new_messages.append(m)
+ else:
+ new_messages.append(
+ {
+ "role": default_role,
+ **{key: value for key, value in m.items() if key != "role"},
+ }
+ )
+
+ return new_messages
+
+ async def prepare_aporia_request(
+ self, new_messages: List[dict], response_string: Optional[str] = None
+ ) -> dict:
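+        """
+        Build the Aporia validation payload: include the request messages and/or
+        the LLM response string, and set `validation_target` to "prompt",
+        "response", or "both" depending on what is present.
+        """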
+ data: dict[str, Any] = {}
+ if new_messages is not None:
+ data["messages"] = new_messages
+ if response_string is not None:
+ data["response"] = response_string
+
+ # Set validation target
+ if new_messages and response_string:
+ data["validation_target"] = "both"
+ elif new_messages:
+ data["validation_target"] = "prompt"
+ elif response_string:
+ data["validation_target"] = "response"
+
+ verbose_proxy_logger.debug("Aporia AI request: %s", data)
+ return data
+
+ async def make_aporia_api_request(
+ self, new_messages: List[dict], response_string: Optional[str] = None
+ ):
+ data = await self.prepare_aporia_request(
+ new_messages=new_messages, response_string=response_string
+ )
+
+ _json_data = json.dumps(data)
+
+ """
+ export APORIO_API_KEY=
+ curl https://gr-prd-trial.aporia.com/some-id \
+ -X POST \
+ -H "X-APORIA-API-KEY: $APORIO_API_KEY" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "messages": [
+ {
+ "role": "user",
+ "content": "This is a test prompt"
+ }
+ ],
+ }
+'
+ """
+
+ response = await self.async_handler.post(
+ url=self.aporia_api_base + "/validate",
+ data=_json_data,
+ headers={
+ "X-APORIA-API-KEY": self.aporia_api_key,
+ "Content-Type": "application/json",
+ },
+ )
+ verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
+ if response.status_code == 200:
+ # check if the response was flagged
+ _json_response = response.json()
+ action: str = _json_response.get(
+ "action"
+ ) # possible values are modify, passthrough, block, rephrase
+ if action == "block":
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Violated guardrail policy",
+ "aporia_ai_response": _json_response,
+ },
+ )
+
+ async def async_post_call_success_hook(
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ response,
+ ):
+ from litellm.proxy.common_utils.callback_utils import (
+ add_guardrail_to_applied_guardrails_header,
+ )
+ from litellm.types.guardrails import GuardrailEventHooks
+
+ """
+ Use this for the post call moderation with Guardrails
+ """
+ event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
+ if self.should_run_guardrail(data=data, event_type=event_type) is not True:
+ return
+
+ response_str: Optional[str] = convert_litellm_response_object_to_str(response)
+ if response_str is not None:
+ await self.make_aporia_api_request(
+ response_string=response_str, new_messages=data.get("messages", [])
+ )
+
+ add_guardrail_to_applied_guardrails_header(
+ request_data=data, guardrail_name=self.guardrail_name
+ )
+
+ pass
+
+ async def async_moderation_hook( ### 👈 KEY CHANGE ###
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ call_type: Literal["completion", "embeddings", "image_generation"],
+ ):
+ from litellm.proxy.common_utils.callback_utils import (
+ add_guardrail_to_applied_guardrails_header,
+ )
+ from litellm.types.guardrails import GuardrailEventHooks
+
+ event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
+ if self.should_run_guardrail(data=data, event_type=event_type) is not True:
+ return
+
+ # old implementation - backwards compatibility
+ if (
+ await should_proceed_based_on_metadata(
+ data=data,
+ guardrail_name=GUARDRAIL_NAME,
+ )
+ is False
+ ):
+ return
+
+ new_messages: Optional[List[dict]] = None
+ if "messages" in data and isinstance(data["messages"], list):
+ new_messages = self.transform_messages(messages=data["messages"])
+
+ if new_messages is not None:
+ await self.make_aporia_api_request(new_messages=new_messages)
+ add_guardrail_to_applied_guardrails_header(
+ request_data=data, guardrail_name=self.guardrail_name
+ )
+ else:
+ verbose_proxy_logger.warning(
+ "Aporia AI: not running guardrail. No messages in data"
+ )
+ pass
diff --git a/enterprise/enterprise_hooks/aporio_ai.py b/enterprise/enterprise_hooks/aporio_ai.py
deleted file mode 100644
index ce8de6eca0..0000000000
--- a/enterprise/enterprise_hooks/aporio_ai.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# +-------------------------------------------------------------+
-#
-# Use AporioAI for your LLM calls
-#
-# +-------------------------------------------------------------+
-# Thank you users! We ❤️ you! - Krrish & Ishaan
-
-import sys, os
-
-sys.path.insert(
- 0, os.path.abspath("../..")
-) # Adds the parent directory to the system path
-from typing import Optional, Literal, Union
-import litellm, traceback, sys, uuid
-from litellm.caching import DualCache
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.integrations.custom_logger import CustomLogger
-from fastapi import HTTPException
-from litellm._logging import verbose_proxy_logger
-from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
-from typing import List
-from datetime import datetime
-import aiohttp, asyncio
-from litellm._logging import verbose_proxy_logger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
-import httpx
-import json
-
-litellm.set_verbose = True
-
-GUARDRAIL_NAME = "aporio"
-
-
-class _ENTERPRISE_Aporio(CustomLogger):
- def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
- self.async_handler = AsyncHTTPHandler(
- timeout=httpx.Timeout(timeout=600.0, connect=5.0)
- )
- self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"]
- self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"]
-
- #### CALL HOOKS - proxy only ####
- def transform_messages(self, messages: List[dict]) -> List[dict]:
- supported_openai_roles = ["system", "user", "assistant"]
- default_role = "other" # for unsupported roles - e.g. tool
- new_messages = []
- for m in messages:
- if m.get("role", "") in supported_openai_roles:
- new_messages.append(m)
- else:
- new_messages.append(
- {
- "role": default_role,
- **{key: value for key, value in m.items() if key != "role"},
- }
- )
-
- return new_messages
-
- async def async_moderation_hook( ### 👈 KEY CHANGE ###
- self,
- data: dict,
- user_api_key_dict: UserAPIKeyAuth,
- call_type: Literal["completion", "embeddings", "image_generation"],
- ):
-
- if (
- await should_proceed_based_on_metadata(
- data=data,
- guardrail_name=GUARDRAIL_NAME,
- )
- is False
- ):
- return
-
- new_messages: Optional[List[dict]] = None
- if "messages" in data and isinstance(data["messages"], list):
- new_messages = self.transform_messages(messages=data["messages"])
-
- if new_messages is not None:
- data = {"messages": new_messages, "validation_target": "prompt"}
-
- _json_data = json.dumps(data)
-
- """
- export APORIO_API_KEY=
- curl https://gr-prd-trial.aporia.com/some-id \
- -X POST \
- -H "X-APORIA-API-KEY: $APORIO_API_KEY" \
- -H "Content-Type: application/json" \
- -d '{
- "messages": [
- {
- "role": "user",
- "content": "This is a test prompt"
- }
- ],
- }
-'
- """
-
- response = await self.async_handler.post(
- url=self.aporio_api_base + "/validate",
- data=_json_data,
- headers={
- "X-APORIA-API-KEY": self.aporio_api_key,
- "Content-Type": "application/json",
- },
- )
- verbose_proxy_logger.debug("Aporio AI response: %s", response.text)
- if response.status_code == 200:
- # check if the response was flagged
- _json_response = response.json()
- action: str = _json_response.get(
- "action"
- ) # possible values are modify, passthrough, block, rephrase
- if action == "block":
- raise HTTPException(
- status_code=400,
- detail={
- "error": "Violated guardrail policy",
- "aporio_ai_response": _json_response,
- },
- )
diff --git a/enterprise/enterprise_hooks/banned_keywords.py b/enterprise/enterprise_hooks/banned_keywords.py
index 4d6545eb07..e282ee5ab4 100644
--- a/enterprise/enterprise_hooks/banned_keywords.py
+++ b/enterprise/enterprise_hooks/banned_keywords.py
@@ -90,6 +90,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
async def async_post_call_success_hook(
self,
+ data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py
new file mode 100644
index 0000000000..a3ac2ea863
--- /dev/null
+++ b/litellm/integrations/custom_guardrail.py
@@ -0,0 +1,32 @@
+from typing import Literal
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.guardrails import GuardrailEventHooks
+
+
+class CustomGuardrail(CustomLogger):
+
+ def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs):
+ self.guardrail_name = guardrail_name
+ self.event_hook: GuardrailEventHooks = event_hook
+ super().__init__(**kwargs)
+
+ def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
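+        """
+        Run this guardrail only if it was requested for the call (its name is in
+        `metadata["guardrails"]`) and the event type matches the hook this
+        guardrail was registered with.
+        """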
+ verbose_logger.debug(
+ "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s",
+ self.guardrail_name,
+ event_type,
+ self.event_hook,
+ )
+
+ metadata = data.get("metadata") or {}
+ requested_guardrails = metadata.get("guardrails") or []
+
+ if self.guardrail_name not in requested_guardrails:
+ return False
+
+ if self.event_hook != event_type:
+ return False
+
+ return True
diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index 98b0da25c5..47d28ab56a 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -122,6 +122,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
async def async_post_call_success_hook(
self,
+ data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
diff --git a/litellm/litellm_core_utils/logging_utils.py b/litellm/litellm_core_utils/logging_utils.py
index fdc9672a00..7fa1be9d8b 100644
--- a/litellm/litellm_core_utils/logging_utils.py
+++ b/litellm/litellm_core_utils/logging_utils.py
@@ -1,4 +1,12 @@
-from typing import Any
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+if TYPE_CHECKING:
+ from litellm import ModelResponse as _ModelResponse
+
+ LiteLLMModelResponse = _ModelResponse
+else:
+ LiteLLMModelResponse = Any
+
import litellm
@@ -20,3 +28,21 @@ def convert_litellm_response_object_to_dict(response_obj: Any) -> dict:
# If it's not a LiteLLM type, return the object as is
return dict(response_obj)
+
+
+def convert_litellm_response_object_to_str(
+ response_obj: Union[Any, LiteLLMModelResponse]
+) -> Optional[str]:
+ """
+ Get the string of the response object from LiteLLM
+
+ """
+ if isinstance(response_obj, litellm.ModelResponse):
+ response_str = ""
+ for choice in response_obj.choices:
+ if isinstance(choice, litellm.Choices):
+ if choice.message.content and isinstance(choice.message.content, str):
+ response_str += choice.message.content
+ return response_str
+
+ return None
diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py
index 6b000b148d..26aa28d62a 100644
--- a/litellm/proxy/common_utils/callback_utils.py
+++ b/litellm/proxy/common_utils/callback_utils.py
@@ -118,17 +118,19 @@ def initialize_callbacks_on_proxy(
**init_params
)
imported_list.append(lakera_moderations_object)
- elif isinstance(callback, str) and callback == "aporio_prompt_injection":
- from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio
+ elif isinstance(callback, str) and callback == "aporia_prompt_injection":
+ from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
+ _ENTERPRISE_Aporia,
+ )
if premium_user is not True:
raise Exception(
- "Trying to use Aporio AI Guardrail"
+ "Trying to use Aporia AI Guardrail"
+ CommonProxyErrors.not_premium_user.value
)
- aporio_guardrail_object = _ENTERPRISE_Aporio()
- imported_list.append(aporio_guardrail_object)
+ aporia_guardrail_object = _ENTERPRISE_Aporia()
+ imported_list.append(aporia_guardrail_object)
elif isinstance(callback, str) and callback == "google_text_moderation":
from enterprise.enterprise_hooks.google_text_moderation import (
_ENTERPRISE_GoogleTextModeration,
@@ -295,3 +297,21 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str,
headers[f"x-litellm-key-remaining-tokens-{model_group}"] = remaining_tokens
return headers
+
+
+def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]:
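+    """
+    If any guardrails ran for this request, return the
+    `x-litellm-applied-guardrails` response header (a comma-separated list of
+    guardrail names); otherwise return None.
+    """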
+ _metadata = request_data.get("metadata", None) or {}
+ if "applied_guardrails" in _metadata:
+ return {
+ "x-litellm-applied-guardrails": ",".join(_metadata["applied_guardrails"]),
+ }
+
+ return None
+
+
+def add_guardrail_to_applied_guardrails_header(request_data: Dict, guardrail_name: str):
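+    """Record `guardrail_name` in the request metadata's `applied_guardrails` list (created on first use)."""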
+ _metadata = request_data.get("metadata", None) or {}
+ if "applied_guardrails" in _metadata:
+ _metadata["applied_guardrails"].append(guardrail_name)
+ else:
+ _metadata["applied_guardrails"] = [guardrail_name]
diff --git a/litellm/proxy/custom_callbacks1.py b/litellm/proxy/custom_callbacks1.py
index 37e4a6cdb3..05028f033c 100644
--- a/litellm/proxy/custom_callbacks1.py
+++ b/litellm/proxy/custom_callbacks1.py
@@ -40,6 +40,7 @@ class MyCustomHandler(
async def async_post_call_success_hook(
self,
+ data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
diff --git a/litellm/proxy/example_config_yaml/otel_test_config.yaml b/litellm/proxy/example_config_yaml/otel_test_config.yaml
index 2e25374433..496ae1710d 100644
--- a/litellm/proxy/example_config_yaml/otel_test_config.yaml
+++ b/litellm/proxy/example_config_yaml/otel_test_config.yaml
@@ -9,3 +9,16 @@ litellm_settings:
cache: true
callbacks: ["otel"]
+guardrails:
+ - guardrail_name: "aporia-pre-guard"
+ litellm_params:
+ guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
+ mode: "post_call"
+ api_key: os.environ/APORIA_API_KEY_1
+ api_base: os.environ/APORIA_API_BASE_1
+ - guardrail_name: "aporia-post-guard"
+ litellm_params:
+ guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
+ mode: "post_call"
+ api_key: os.environ/APORIA_API_KEY_2
+ api_base: os.environ/APORIA_API_BASE_2
\ No newline at end of file
diff --git a/litellm/proxy/guardrails/guardrail_helpers.py b/litellm/proxy/guardrails/guardrail_helpers.py
index e0a5f1eb3d..a57b965c8e 100644
--- a/litellm/proxy/guardrails/guardrail_helpers.py
+++ b/litellm/proxy/guardrails/guardrail_helpers.py
@@ -37,32 +37,35 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b
requested_callback_names = []
- # get guardrail configs from `init_guardrails.py`
- # for all requested guardrails -> get their associated callbacks
- for _guardrail_name, should_run in request_guardrails.items():
- if should_run is False:
- verbose_proxy_logger.debug(
- "Guardrail %s skipped because request set to False",
- _guardrail_name,
- )
- continue
+ # v1 implementation of this
+ if isinstance(request_guardrails, dict):
- # lookup the guardrail in guardrail_name_config_map
- guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
- _guardrail_name
- ]
+ # get guardrail configs from `init_guardrails.py`
+ # for all requested guardrails -> get their associated callbacks
+ for _guardrail_name, should_run in request_guardrails.items():
+ if should_run is False:
+ verbose_proxy_logger.debug(
+ "Guardrail %s skipped because request set to False",
+ _guardrail_name,
+ )
+ continue
- guardrail_callbacks = guardrail_item.callbacks
- requested_callback_names.extend(guardrail_callbacks)
+ # lookup the guardrail in guardrail_name_config_map
+ guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
+ _guardrail_name
+ ]
- verbose_proxy_logger.debug(
- "requested_callback_names %s", requested_callback_names
- )
- if guardrail_name in requested_callback_names:
- return True
+ guardrail_callbacks = guardrail_item.callbacks
+ requested_callback_names.extend(guardrail_callbacks)
- # Do no proceeed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } }
- return False
+ verbose_proxy_logger.debug(
+ "requested_callback_names %s", requested_callback_names
+ )
+ if guardrail_name in requested_callback_names:
+ return True
+
+    # Do not proceed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } }
+ return False
return True
diff --git a/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py
new file mode 100644
index 0000000000..29566d94db
--- /dev/null
+++ b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py
@@ -0,0 +1,212 @@
+# +-------------------------------------------------------------+
+#
+# Use AporiaAI for your LLM calls
+#
+# +-------------------------------------------------------------+
+# Thank you users! We ❤️ you! - Krrish & Ishaan
+
+import os
+import sys
+
+sys.path.insert(
+ 0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+import asyncio
+import json
+import sys
+import traceback
+import uuid
+from datetime import datetime
+from typing import Any, List, Literal, Optional, Union
+
+import aiohttp
+import httpx
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.litellm_core_utils.logging_utils import (
+ convert_litellm_response_object_to_str,
+)
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
+from litellm.types.guardrails import GuardrailEventHooks
+
+litellm.set_verbose = True
+
+GUARDRAIL_NAME = "aporia"
+
+
+class _ENTERPRISE_Aporia(CustomGuardrail):
+ def __init__(
+ self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
+ ):
+ self.async_handler = AsyncHTTPHandler(
+ timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+ )
+ self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
+ self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
+ self.event_hook: GuardrailEventHooks
+
+ super().__init__(**kwargs)
+
+ #### CALL HOOKS - proxy only ####
+ def transform_messages(self, messages: List[dict]) -> List[dict]:
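+        """
+        Replace any role outside system/user/assistant (e.g. "tool") with a
+        generic "other" role, keeping the rest of each message unchanged.
+        """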
+ supported_openai_roles = ["system", "user", "assistant"]
+ default_role = "other" # for unsupported roles - e.g. tool
+ new_messages = []
+ for m in messages:
+ if m.get("role", "") in supported_openai_roles:
+ new_messages.append(m)
+ else:
+ new_messages.append(
+ {
+ "role": default_role,
+ **{key: value for key, value in m.items() if key != "role"},
+ }
+ )
+
+ return new_messages
+
+ async def prepare_aporia_request(
+ self, new_messages: List[dict], response_string: Optional[str] = None
+ ) -> dict:
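+        """
+        Build the Aporia validation payload: include the request messages and/or
+        the LLM response string, and set `validation_target` to "prompt",
+        "response", or "both" depending on what is present.
+        """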
+ data: dict[str, Any] = {}
+ if new_messages is not None:
+ data["messages"] = new_messages
+ if response_string is not None:
+ data["response"] = response_string
+
+ # Set validation target
+ if new_messages and response_string:
+ data["validation_target"] = "both"
+ elif new_messages:
+ data["validation_target"] = "prompt"
+ elif response_string:
+ data["validation_target"] = "response"
+
+ verbose_proxy_logger.debug("Aporia AI request: %s", data)
+ return data
+
+ async def make_aporia_api_request(
+ self, new_messages: List[dict], response_string: Optional[str] = None
+ ):
+ data = await self.prepare_aporia_request(
+ new_messages=new_messages, response_string=response_string
+ )
+
+ _json_data = json.dumps(data)
+
+ """
+ export APORIO_API_KEY=
+ curl https://gr-prd-trial.aporia.com/some-id \
+ -X POST \
+ -H "X-APORIA-API-KEY: $APORIO_API_KEY" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "messages": [
+ {
+ "role": "user",
+ "content": "This is a test prompt"
+ }
+ ],
+ }
+'
+ """
+
+ response = await self.async_handler.post(
+ url=self.aporia_api_base + "/validate",
+ data=_json_data,
+ headers={
+ "X-APORIA-API-KEY": self.aporia_api_key,
+ "Content-Type": "application/json",
+ },
+ )
+ verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
+ if response.status_code == 200:
+ # check if the response was flagged
+ _json_response = response.json()
+ action: str = _json_response.get(
+ "action"
+ ) # possible values are modify, passthrough, block, rephrase
+ if action == "block":
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Violated guardrail policy",
+ "aporia_ai_response": _json_response,
+ },
+ )
+
+ async def async_post_call_success_hook(
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ response,
+ ):
+ from litellm.proxy.common_utils.callback_utils import (
+ add_guardrail_to_applied_guardrails_header,
+ )
+ from litellm.types.guardrails import GuardrailEventHooks
+
+ """
+ Use this for the post call moderation with Guardrails
+ """
+ event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
+ if self.should_run_guardrail(data=data, event_type=event_type) is not True:
+ return
+
+ response_str: Optional[str] = convert_litellm_response_object_to_str(response)
+ if response_str is not None:
+ await self.make_aporia_api_request(
+ response_string=response_str, new_messages=data.get("messages", [])
+ )
+
+ add_guardrail_to_applied_guardrails_header(
+ request_data=data, guardrail_name=self.guardrail_name
+ )
+
+ pass
+
+ async def async_moderation_hook( ### 👈 KEY CHANGE ###
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ call_type: Literal["completion", "embeddings", "image_generation"],
+ ):
+ from litellm.proxy.common_utils.callback_utils import (
+ add_guardrail_to_applied_guardrails_header,
+ )
+ from litellm.types.guardrails import GuardrailEventHooks
+
+ event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
+ if self.should_run_guardrail(data=data, event_type=event_type) is not True:
+ return
+
+ # old implementation - backwards compatibility
+ if (
+ await should_proceed_based_on_metadata(
+ data=data,
+ guardrail_name=GUARDRAIL_NAME,
+ )
+ is False
+ ):
+ return
+
+ new_messages: Optional[List[dict]] = None
+ if "messages" in data and isinstance(data["messages"], list):
+ new_messages = self.transform_messages(messages=data["messages"])
+
+ if new_messages is not None:
+ await self.make_aporia_api_request(new_messages=new_messages)
+ add_guardrail_to_applied_guardrails_header(
+ request_data=data, guardrail_name=self.guardrail_name
+ )
+ else:
+ verbose_proxy_logger.warning(
+ "Aporia AI: not running guardrail. No messages in data"
+ )
+ pass
diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py
index 8bf476311f..f5ed9fee84 100644
--- a/litellm/proxy/guardrails/init_guardrails.py
+++ b/litellm/proxy/guardrails/init_guardrails.py
@@ -1,12 +1,20 @@
import traceback
-from typing import Dict, List
+from typing import Dict, List, Literal
from pydantic import BaseModel, RootModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
-from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
+
+# v2 implementation
+from litellm.types.guardrails import (
+ Guardrail,
+ GuardrailItem,
+ GuardrailItemSpec,
+ LitellmParams,
+ guardrailConfig,
+)
all_guardrails: List[GuardrailItem] = []
@@ -66,3 +74,68 @@ def initialize_guardrails(
"error initializing guardrails {}".format(str(e))
)
raise e
+
+
+"""
+Map guardrail_name: , , during_call
+
+"""
+
+
+def init_guardrails_v2(all_guardrails: dict):
+ # Convert the loaded data to the TypedDict structure
+ guardrail_list = []
+
+ # Parse each guardrail and replace environment variables
+ for guardrail in all_guardrails:
+
+ # Init litellm params for guardrail
+ litellm_params_data = guardrail["litellm_params"]
+ verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
+ litellm_params = LitellmParams(
+ guardrail=litellm_params_data["guardrail"],
+ mode=litellm_params_data["mode"],
+ api_key=litellm_params_data["api_key"],
+ api_base=litellm_params_data["api_base"],
+ )
+
+ if litellm_params["api_key"]:
+ if litellm_params["api_key"].startswith("os.environ/"):
+ litellm_params["api_key"] = litellm.get_secret(
+ litellm_params["api_key"]
+ )
+
+ if litellm_params["api_base"]:
+ if litellm_params["api_base"].startswith("os.environ/"):
+ litellm_params["api_base"] = litellm.get_secret(
+ litellm_params["api_base"]
+ )
+
+ # Init guardrail CustomLoggerClass
+ if litellm_params["guardrail"] == "aporia":
+        from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
+            _ENTERPRISE_Aporia,
+        )
+
+ _aporia_callback = _ENTERPRISE_Aporia(
+ api_base=litellm_params["api_base"],
+ api_key=litellm_params["api_key"],
+ guardrail_name=guardrail["guardrail_name"],
+ event_hook=litellm_params["mode"],
+ )
+ litellm.callbacks.append(_aporia_callback) # type: ignore
+ elif litellm_params["guardrail"] == "lakera":
+ from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
+ _ENTERPRISE_lakeraAI_Moderation,
+ )
+
+ _lakera_callback = _ENTERPRISE_lakeraAI_Moderation()
+ litellm.callbacks.append(_lakera_callback) # type: ignore
+
+ parsed_guardrail = Guardrail(
+ guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params
+ )
+
+ guardrail_list.append(parsed_guardrail)
+ guardrail_name = guardrail["guardrail_name"]
+
+    # print the parsed guardrail list
+ print(f"\nGuardrail List:{guardrail_list}\n") # noqa
diff --git a/litellm/proxy/hooks/azure_content_safety.py b/litellm/proxy/hooks/azure_content_safety.py
index 972ac99928..ccadafaf2e 100644
--- a/litellm/proxy/hooks/azure_content_safety.py
+++ b/litellm/proxy/hooks/azure_content_safety.py
@@ -1,11 +1,16 @@
-from litellm.integrations.custom_logger import CustomLogger
-from litellm.caching import DualCache
-from litellm.proxy._types import UserAPIKeyAuth
-import litellm, traceback, sys, uuid
-from fastapi import HTTPException
-from litellm._logging import verbose_proxy_logger
+import sys
+import traceback
+import uuid
from typing import Optional
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
class _PROXY_AzureContentSafety(
CustomLogger
@@ -15,12 +20,12 @@ class _PROXY_AzureContentSafety(
def __init__(self, endpoint, api_key, thresholds=None):
try:
from azure.ai.contentsafety.aio import ContentSafetyClient
- from azure.core.credentials import AzureKeyCredential
from azure.ai.contentsafety.models import (
- TextCategory,
AnalyzeTextOptions,
AnalyzeTextOutputType,
+ TextCategory,
)
+ from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
except Exception as e:
raise Exception(
@@ -132,6 +137,7 @@ class _PROXY_AzureContentSafety(
async def async_post_call_success_hook(
self,
+ data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py
index 4bf08998a4..57985e9a69 100644
--- a/litellm/proxy/hooks/dynamic_rate_limiter.py
+++ b/litellm/proxy/hooks/dynamic_rate_limiter.py
@@ -254,7 +254,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
return None
async def async_post_call_success_hook(
- self, user_api_key_dict: UserAPIKeyAuth, response
+ self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
):
try:
if isinstance(response, ModelResponse):
@@ -287,7 +287,9 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
return response
return await super().async_post_call_success_hook(
- user_api_key_dict, response
+ data=data,
+ user_api_key_dict=user_api_key_dict,
+ response=response,
)
except Exception as e:
verbose_proxy_logger.exception(
diff --git a/litellm/proxy/hooks/presidio_pii_masking.py b/litellm/proxy/hooks/presidio_pii_masking.py
index 933d925507..6af7e3d1e5 100644
--- a/litellm/proxy/hooks/presidio_pii_masking.py
+++ b/litellm/proxy/hooks/presidio_pii_masking.py
@@ -322,6 +322,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
async def async_post_call_success_hook(
self,
+ data: dict,
user_api_key_dict: UserAPIKeyAuth,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
):
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 85211f943d..78f3e09492 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -316,9 +316,20 @@ async def add_litellm_data_to_request(
for k, v in callback_settings_obj.callback_vars.items():
data[k] = v
+ # Guardrails
+ move_guardrails_to_metadata(
+ data=data, _metadata_variable_name=_metadata_variable_name
+ )
+
return data
+def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str):
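+    """
+    Move the request-level `guardrails` field into the request metadata, so
+    guardrail hooks can check which guardrails were requested for the call.
+    """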
+ if "guardrails" in data:
+ data[_metadata_variable_name]["guardrails"] = data["guardrails"]
+ del data["guardrails"]
+
+
def add_provider_specific_headers_to_request(
data: dict,
headers: dict,
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index e08be88aad..3c1c64292e 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,50 +1,20 @@
model_list:
- - model_name: gpt-4
+ - model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_info:
- access_groups: ["beta-models"]
- - model_name: fireworks-llama-v3-70b-instruct
- litellm_params:
- model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
- api_key: "os.environ/FIREWORKS"
- model_info:
- access_groups: ["beta-models"]
- - model_name: "*"
- litellm_params:
- model: "*"
- - model_name: "*"
- litellm_params:
- model: openai/*
- api_key: os.environ/OPENAI_API_KEY
- - model_name: mistral-small-latest
- litellm_params:
- model: mistral/mistral-small-latest
- api_key: "os.environ/MISTRAL_API_KEY"
- - model_name: bedrock-anthropic
- litellm_params:
- model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
- - model_name: gemini-1.5-pro-001
- litellm_params:
- model: vertex_ai_beta/gemini-1.5-pro-001
- vertex_project: "adroit-crow-413218"
- vertex_location: "us-central1"
- vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json"
- # Add path to service account.json
-default_vertex_config:
- vertex_project: "adroit-crow-413218"
- vertex_location: "us-central1"
- vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
-
-
-general_settings:
- master_key: sk-1234
- alerting: ["slack"]
-
-litellm_settings:
- fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}]
- success_callback: ["langfuse", "prometheus"]
- langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"]
+guardrails:
+ - guardrail_name: "aporia-pre-guard"
+ litellm_params:
+ guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
+ mode: "post_call"
+ api_key: os.environ/APORIA_API_KEY_1
+ api_base: os.environ/APORIA_API_BASE_1
+ - guardrail_name: "aporia-post-guard"
+ litellm_params:
+ guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
+ mode: "post_call"
+ api_key: os.environ/APORIA_API_KEY_2
+ api_base: os.environ/APORIA_API_BASE_2
\ No newline at end of file
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 3820f0ea33..0fbf10a2b9 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -149,6 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
show_missing_vars_in_env,
)
from litellm.proxy.common_utils.callback_utils import (
+ get_applied_guardrails_header,
get_remaining_tokens_and_requests_from_request_data,
initialize_callbacks_on_proxy,
)
@@ -168,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
)
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
-from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
+from litellm.proxy.guardrails.init_guardrails import (
+ init_guardrails_v2,
+ initialize_guardrails,
+)
from litellm.proxy.health_check import perform_health_check
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.hooks.prompt_injection_detection import (
@@ -539,6 +543,10 @@ def get_custom_headers(
)
headers.update(remaining_tokens_header)
+ applied_guardrails = get_applied_guardrails_header(request_data)
+ if applied_guardrails:
+ headers.update(applied_guardrails)
+
try:
return {
key: value for key, value in headers.items() if value not in exclude_values
@@ -1937,6 +1945,11 @@ class ProxyConfig:
async_only_mode=True # only init async clients
),
) # type:ignore
+
+ # Guardrail settings
+ guardrails_v2 = config.get("guardrails", None)
+ if guardrails_v2:
+ init_guardrails_v2(all_guardrails=guardrails_v2)
return router, router.get_model_list(), general_settings
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
@@ -3139,7 +3152,7 @@ async def chat_completion(
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
- user_api_key_dict=user_api_key_dict, response=response
+ data=data, user_api_key_dict=user_api_key_dict, response=response
)
hidden_params = (
@@ -3353,6 +3366,11 @@ async def completion(
media_type="text/event-stream",
headers=custom_headers,
)
+ ### CALL HOOKS ### - modify outgoing data
+ response = await proxy_logging_obj.post_call_success_hook(
+ data=data, user_api_key_dict=user_api_key_dict, response=response
+ )
+
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index d1d17d0ef5..a2b09b4e69 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -432,12 +432,11 @@ class ProxyLogging:
"""
Runs the CustomLogger's async_moderation_hook()
"""
- new_data = safe_deep_copy(data)
for callback in litellm.callbacks:
try:
if isinstance(callback, CustomLogger):
await callback.async_moderation_hook(
- data=new_data,
+ data=data,
user_api_key_dict=user_api_key_dict,
call_type=call_type,
)
@@ -717,6 +716,7 @@ class ProxyLogging:
async def post_call_success_hook(
self,
+ data: dict,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
user_api_key_dict: UserAPIKeyAuth,
):
@@ -738,7 +738,9 @@ class ProxyLogging:
_callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger):
await _callback.async_post_call_success_hook(
- user_api_key_dict=user_api_key_dict, response=response
+ user_api_key_dict=user_api_key_dict,
+ data=data,
+ response=response,
)
except Exception as e:
raise e
diff --git a/litellm/tests/test_azure_content_safety.py b/litellm/tests/test_azure_content_safety.py
index 7b040fb252..dc80c163c3 100644
--- a/litellm/tests/test_azure_content_safety.py
+++ b/litellm/tests/test_azure_content_safety.py
@@ -1,8 +1,13 @@
# What is this?
## Unit test for azure content safety
-import sys, os, asyncio, time, random
-from datetime import datetime
+import asyncio
+import os
+import random
+import sys
+import time
import traceback
+from datetime import datetime
+
from dotenv import load_dotenv
from fastapi import HTTPException
@@ -13,11 +18,12 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
+
import litellm
from litellm import Router, mock_completion
-from litellm.proxy.utils import ProxyLogging
-from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.utils import ProxyLogging
@pytest.mark.asyncio
@@ -177,7 +183,13 @@ async def test_strict_output_filtering_01():
with pytest.raises(HTTPException) as exc_info:
await azure_content_safety.async_post_call_success_hook(
- user_api_key_dict=UserAPIKeyAuth(), response=response
+ user_api_key_dict=UserAPIKeyAuth(),
+ data={
+ "messages": [
+ {"role": "system", "content": "You are an helpfull assistant"}
+ ]
+ },
+ response=response,
)
assert exc_info.value.detail["source"] == "output"
@@ -216,7 +228,11 @@ async def test_strict_output_filtering_02():
)
await azure_content_safety.async_post_call_success_hook(
- user_api_key_dict=UserAPIKeyAuth(), response=response
+ user_api_key_dict=UserAPIKeyAuth(),
+ data={
+ "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+ },
+ response=response,
)
@@ -251,7 +267,11 @@ async def test_loose_output_filtering_01():
)
await azure_content_safety.async_post_call_success_hook(
- user_api_key_dict=UserAPIKeyAuth(), response=response
+ user_api_key_dict=UserAPIKeyAuth(),
+ data={
+ "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+ },
+ response=response,
)
@@ -286,5 +306,9 @@ async def test_loose_output_filtering_02():
)
await azure_content_safety.async_post_call_success_hook(
- user_api_key_dict=UserAPIKeyAuth(), response=response
+ user_api_key_dict=UserAPIKeyAuth(),
+ data={
+ "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+ },
+ response=response,
)
diff --git a/litellm/tests/test_presidio_masking.py b/litellm/tests/test_presidio_masking.py
index 193fcf113b..35a03ea5e3 100644
--- a/litellm/tests/test_presidio_masking.py
+++ b/litellm/tests/test_presidio_masking.py
@@ -88,7 +88,11 @@ async def test_output_parsing():
mock_response="Hello ! How can I assist you today?",
)
new_response = await pii_masking.async_post_call_success_hook(
- user_api_key_dict=UserAPIKeyAuth(), response=response
+ user_api_key_dict=UserAPIKeyAuth(),
+ data={
+ "messages": [{"role": "system", "content": "You are an helpfull assistant"}]
+ },
+ response=response,
)
assert (
diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py
index 0296d8de4a..cd9f76f171 100644
--- a/litellm/types/guardrails.py
+++ b/litellm/types/guardrails.py
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, TypedDict
from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict
@@ -63,3 +63,26 @@ class GuardrailItem(BaseModel):
enabled_roles=enabled_roles,
callback_args=callback_args,
)
+
+
+# Define the TypedDicts
+class LitellmParams(TypedDict):
+ guardrail: str
+ mode: str
+ api_key: str
+ api_base: Optional[str]
+
+
+class Guardrail(TypedDict):
+ guardrail_name: str
+ litellm_params: LitellmParams
+
+
+class guardrailConfig(TypedDict):
+ guardrails: List[Guardrail]
+
+
+class GuardrailEventHooks(str, Enum):
+ pre_call = "pre_call"
+ post_call = "post_call"
+ during_call = "during_call"
diff --git a/tests/otel_tests/test_guardrails.py b/tests/otel_tests/test_guardrails.py
new file mode 100644
index 0000000000..c48a5ba79b
--- /dev/null
+++ b/tests/otel_tests/test_guardrails.py
@@ -0,0 +1,118 @@
+import pytest
+import asyncio
+import aiohttp, openai
+from openai import OpenAI, AsyncOpenAI
+from typing import Optional, List, Union
+import uuid
+
+
+async def chat_completion(
+ session,
+ key,
+ messages,
+ model: Union[str, List] = "gpt-4",
+ guardrails: Optional[List] = None,
+):
+ url = "http://0.0.0.0:4000/chat/completions"
+ headers = {
+ "Authorization": f"Bearer {key}",
+ "Content-Type": "application/json",
+ }
+
+ data = {
+ "model": model,
+ "messages": messages,
+ "guardrails": [
+ "aporia-post-guard",
+ "aporia-pre-guard",
+ ], # default guardrails for all tests
+ }
+
+ if guardrails is not None:
+ data["guardrails"] = guardrails
+
+ print("data=", data)
+
+ async with session.post(url, headers=headers, json=data) as response:
+ status = response.status
+ response_text = await response.text()
+
+ print(response_text)
+ print()
+
+ if status != 200:
+ return response_text
+
+ # response headers
+ response_headers = response.headers
+ print("response headers=", response_headers)
+
+ return await response.json(), response_headers
+
+
+@pytest.mark.asyncio
+async def test_llm_guard_triggered_safe_request():
+ """
+ - Tests a request where no content mod is triggered
+ - Assert that the guardrails applied are returned in the response headers
+ """
+ async with aiohttp.ClientSession() as session:
+ response, headers = await chat_completion(
+ session,
+ "sk-1234",
+ model="fake-openai-endpoint",
+ messages=[{"role": "user", "content": f"Hello what's the weather"}],
+ )
+ await asyncio.sleep(3)
+
+ print("response=", response, "response headers", headers)
+
+ assert "x-litellm-applied-guardrails" in headers
+
+ assert (
+ headers["x-litellm-applied-guardrails"]
+ == "aporia-pre-guard,aporia-post-guard"
+ )
+
+
+@pytest.mark.asyncio
+async def test_llm_guard_triggered():
+ """
+    - Tests a request that triggers the Aporia guardrail (PII in the message)
+    - Assert the request is blocked and "Aporia detected and blocked PII" is in the error
+ """
+ async with aiohttp.ClientSession() as session:
+ try:
+ response, headers = await chat_completion(
+ session,
+ "sk-1234",
+ model="fake-openai-endpoint",
+ messages=[
+ {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
+ ],
+ )
+ pytest.fail("Should have thrown an exception")
+ except Exception as e:
+ print(e)
+ assert "Aporia detected and blocked PII" in str(e)
+
+
+@pytest.mark.asyncio
+async def test_no_llm_guard_triggered():
+ """
+    - Tests a request with `guardrails=[]`, so no guardrails run
+    - Assert that the applied guardrails header is not returned in the response headers
+ """
+ async with aiohttp.ClientSession() as session:
+ response, headers = await chat_completion(
+ session,
+ "sk-1234",
+ model="fake-openai-endpoint",
+ messages=[{"role": "user", "content": f"Hello what's the weather"}],
+ guardrails=[],
+ )
+ await asyncio.sleep(3)
+
+ print("response=", response, "response headers", headers)
+
+ assert "x-litellm-applied-guardrails" not in headers