diff --git a/.circleci/config.yml b/.circleci/config.yml index b43a8aa64c..854bb40f71 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -317,6 +317,10 @@ jobs: -e OPENAI_API_KEY=$OPENAI_API_KEY \ -e LITELLM_LICENSE=$LITELLM_LICENSE \ -e OTEL_EXPORTER="in_memory" \ + -e APORIA_API_BASE_2=$APORIA_API_BASE_2 \ + -e APORIA_API_KEY_2=$APORIA_API_KEY_2 \ + -e APORIA_API_BASE_1=$APORIA_API_BASE_1 \ + -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \ --name my-app \ -v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \ my-app:latest \ diff --git a/docs/my-website/docs/proxy/call_hooks.md b/docs/my-website/docs/proxy/call_hooks.md index ce34e5ad6b..25a46609d3 100644 --- a/docs/my-website/docs/proxy/call_hooks.md +++ b/docs/my-website/docs/proxy/call_hooks.md @@ -47,6 +47,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md index 33a899222b..94813e354b 100644 --- a/docs/my-website/docs/proxy/enterprise.md +++ b/docs/my-website/docs/proxy/enterprise.md @@ -36,7 +36,7 @@ Features: - **Guardrails, PII Masking, Content Moderation** - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation) - ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai) - - ✅ [Prompt Injection Detection (with Aporio API)](#prompt-injection-detection---aporio-ai) + - ✅ [Prompt Injection Detection (with Aporia API)](#prompt-injection-detection---aporia-ai) - ✅ [Switch LakeraAI on / off per request](guardrails#control-guardrails-onoff-per-request) - ✅ Reject calls from Blocked User list - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors) @@ -1035,9 +1035,9 @@ curl --location 'http://localhost:4000/chat/completions' \ Need to control LakeraAI per Request ? Doc here 👉: [Switch LakerAI on / off per request](prompt_injection.md#✨-enterprise-switch-lakeraai-on--off-per-api-call) ::: -## Prompt Injection Detection - Aporio AI +## Prompt Injection Detection - Aporia AI -Use this if you want to reject /chat/completion calls that have prompt injection attacks with [AporioAI](https://www.aporia.com/) +Use this if you want to reject /chat/completion calls that have prompt injection attacks with [AporiaAI](https://www.aporia.com/) #### Usage @@ -1048,11 +1048,11 @@ APORIO_API_KEY="eyJh****" APORIO_API_BASE="https://gr..." ``` -Step 2. Add `aporio_prompt_injection` to your callbacks +Step 2. Add `aporia_prompt_injection` to your callbacks ```yaml litellm_settings: - callbacks: ["aporio_prompt_injection"] + callbacks: ["aporia_prompt_injection"] ``` That's it, start your proxy @@ -1081,7 +1081,7 @@ curl --location 'http://localhost:4000/chat/completions' \ "error": { "message": { "error": "Violated guardrail policy", - "aporio_ai_response": { + "aporia_ai_response": { "action": "block", "revised_prompt": null, "revised_response": "Profanity detected: Message blocked because it includes profanity. Please rephrase.", @@ -1097,7 +1097,7 @@ curl --location 'http://localhost:4000/chat/completions' \ :::info -Need to control AporioAI per Request ? Doc here 👉: [Create a guardrail](./guardrails.md) +Need to control AporiaAI per Request ? 
Doc here 👉: [Create a guardrail](./guardrails.md) ::: diff --git a/docs/my-website/docs/proxy/guardrails.md b/docs/my-website/docs/proxy/guardrails.md index 698e97f9a8..451ca8ab50 100644 --- a/docs/my-website/docs/proxy/guardrails.md +++ b/docs/my-website/docs/proxy/guardrails.md @@ -1,18 +1,10 @@ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; -# 🛡️ Guardrails +# 🛡️ [Beta] Guardrails Setup Prompt Injection Detection, Secret Detection on LiteLLM Proxy -:::info - -✨ Enterprise Only Feature - -Schedule a meeting with us to get an Enterprise License 👉 Talk to founders [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) - -::: - ## Quick Start ### 1. Setup guardrails on litellm proxy config.yaml diff --git a/docs/my-website/docs/tutorials/ab_test_llms.md b/docs/my-website/docs/tutorials/ab_test_llms.md deleted file mode 100644 index b08e913529..0000000000 --- a/docs/my-website/docs/tutorials/ab_test_llms.md +++ /dev/null @@ -1,98 +0,0 @@ -import Image from '@theme/IdealImage'; - -# Split traffic betwen GPT-4 and Llama2 in Production! -In this tutorial, we'll walk through A/B testing between GPT-4 and Llama2 in production. We'll assume you've deployed Llama2 on Huggingface Inference Endpoints (but any of TogetherAI, Baseten, Ollama, Petals, Openrouter should work as well). - - -# Relevant Resources: - -* 🚀 [Your production dashboard!](https://admin.litellm.ai/) - - -* [Deploying models on Huggingface](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint) -* [All supported providers on LiteLLM](https://docs.litellm.ai/docs/providers) - -# Code Walkthrough - -In production, we don't know if Llama2 is going to provide: -* good results -* quickly - -### 💡 Route 20% traffic to Llama2 -If Llama2 returns poor answers / is extremely slow, we want to roll-back this change, and use GPT-4 instead. - -Instead of routing 100% of our traffic to Llama2, let's **start by routing 20% traffic** to it and see how it does. - -```python -## route 20% of responses to Llama2 -split_per_model = { - "gpt-4": 0.8, - "huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2 -} -``` - -## 👨‍💻 Complete Code - -### a) For Local -If we're testing this in a script - this is what our complete code looks like. -```python -from litellm import completion_with_split_tests -import os - -## set ENV variables -os.environ["OPENAI_API_KEY"] = "openai key" -os.environ["HUGGINGFACE_API_KEY"] = "huggingface key" - -## route 20% of responses to Llama2 -split_per_model = { - "gpt-4": 0.8, - "huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2 -} - -messages = [{ "content": "Hello, how are you?","role": "user"}] - -completion_with_split_tests( - models=split_per_model, - messages=messages, -) -``` - -### b) For Production - -If we're in production, we don't want to keep going to code to change model/test details (prompt, split%, etc.) for our completion function and redeploying changes. - -LiteLLM exposes a client dashboard to do this in a UI - and instantly updates our completion function in prod. 
- -#### Relevant Code - -```python -completion_with_split_tests(..., use_client=True, id="my-unique-id") -``` - -#### Complete Code - -```python -from litellm import completion_with_split_tests -import os - -## set ENV variables -os.environ["OPENAI_API_KEY"] = "openai key" -os.environ["HUGGINGFACE_API_KEY"] = "huggingface key" - -## route 20% of responses to Llama2 -split_per_model = { - "gpt-4": 0.8, - "huggingface/https://my-unique-endpoint.us-east-1.aws.endpoints.huggingface.cloud": 0.2 -} - -messages = [{ "content": "Hello, how are you?","role": "user"}] - -completion_with_split_tests( - models=split_per_model, - messages=messages, - use_client=True, - id="my-unique-id" # Auto-create this @ https://admin.litellm.ai/ -) -``` - - diff --git a/docs/my-website/docs/tutorials/litellm_proxy_aporia.md b/docs/my-website/docs/tutorials/litellm_proxy_aporia.md new file mode 100644 index 0000000000..480c411c0f --- /dev/null +++ b/docs/my-website/docs/tutorials/litellm_proxy_aporia.md @@ -0,0 +1,196 @@ +import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Use LiteLLM AI Gateway with Aporia Guardrails + +In this tutorial we will use LiteLLM Proxy with Aporia to detect PII in requests and profanity in responses + +## 1. Setup guardrails on Aporia + +### Create Aporia Projects + +Create two projects on [Aporia](https://guardrails.aporia.com/) + +1. Pre LLM API Call - Set all the policies you want to run on pre LLM API call +2. Post LLM API Call - Set all the policies you want to run post LLM API call + + + + + +### Pre-Call: Detect PII + +Add the `PII - Prompt` to your Pre LLM API Call project + + + +### Post-Call: Detect Profanity in Responses + +Add the `Toxicity - Response` to your Post LLM API Call project + + + + +## 2. Define Guardrails on your LiteLLM config.yaml + +- Define your guardrails under the `guardrails` section and set `pre_call_guardrails` and `post_call_guardrails` +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: openai/gpt-3.5-turbo + api_key: os.environ/OPENAI_API_KEY + +guardrails: + - guardrail_name: "aporia-pre-guard" + litellm_params: + guardrail: aporia # supported values: "aporia", "lakera" + mode: "during_call" + api_key: os.environ/APORIA_API_KEY_1 + api_base: os.environ/APORIA_API_BASE_1 + - guardrail_name: "aporia-post-guard" + litellm_params: + guardrail: aporia # supported values: "aporia", "lakera" + mode: "post_call" + api_key: os.environ/APORIA_API_KEY_2 + api_base: os.environ/APORIA_API_BASE_2 +``` + +### Supported values for `mode` + +- `pre_call` Run **before** LLM call, on **input** +- `post_call` Run **after** LLM call, on **input & output** +- `during_call` Run **during** LLM call, on **input** + +## 3. Start LiteLLM Gateway + + +```shell +litellm --config config.yaml --detailed_debug +``` + +## 4. 
Test request + + + + +Expect this to fail since since `ishaan@berri.ai` in the request is PII + +```shell +curl -i http://localhost:4000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \ + -d '{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "hi my email is ishaan@berri.ai"} + ], + "guardrails": ["aporia-pre-guard", "aporia-post-guard"] + }' +``` + +Expected response on failure + +```shell +{ + "error": { + "message": { + "error": "Violated guardrail policy", + "aporia_ai_response": { + "action": "block", + "revised_prompt": null, + "revised_response": "Aporia detected and blocked PII", + "explain_log": null + } + }, + "type": "None", + "param": "None", + "code": "400" + } +} + +``` + + + + + +```shell +curl -i http://localhost:4000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \ + -d '{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "hi what is the weather"} + ], + "guardrails": ["aporia-pre-guard", "aporia-post-guard"] + }' +``` + + + + + + +## Advanced +### Control Guardrails per Project (API Key) + +Use this to control what guardrails run per project. In this tutorial we only want the following guardrails to run for 1 project +- `pre_call_guardrails`: ["aporia-pre-guard"] +- `post_call_guardrails`: ["aporia-post-guard"] + +**Step 1** Create Key with guardrail settings + + + + +```shell +curl -X POST 'http://0.0.0.0:4000/key/generate' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -D '{ + "pre_call_guardrails": ["aporia-pre-guard"], + "post_call_guardrails": ["aporia"] + } + }' +``` + + + + +```shell +curl --location 'http://0.0.0.0:4000/key/update' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "key": "sk-jNm1Zar7XfNdZXp49Z1kSQ", + "pre_call_guardrails": ["aporia"], + "post_call_guardrails": ["aporia"] + } +}' +``` + + + + +**Step 2** Test it with new key + +```shell +curl --location 'http://0.0.0.0:4000/chat/completions' \ + --header 'Authorization: Bearer sk-jNm1Zar7XfNdZXp49Z1kSQ' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "my email is ishaan@berri.ai" + } + ] +}' +``` + + + diff --git a/docs/my-website/img/aporia_post.png b/docs/my-website/img/aporia_post.png new file mode 100644 index 0000000000..5e4d4a287b Binary files /dev/null and b/docs/my-website/img/aporia_post.png differ diff --git a/docs/my-website/img/aporia_pre.png b/docs/my-website/img/aporia_pre.png new file mode 100644 index 0000000000..8df1cfdda9 Binary files /dev/null and b/docs/my-website/img/aporia_pre.png differ diff --git a/docs/my-website/img/aporia_projs.png b/docs/my-website/img/aporia_projs.png new file mode 100644 index 0000000000..c518fdf0bd Binary files /dev/null and b/docs/my-website/img/aporia_projs.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 1e550e6e77..6501ebd757 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -250,6 +250,7 @@ const sidebars = { type: "category", label: "Tutorials", items: [ + 'tutorials/litellm_proxy_aporia', 'tutorials/azure_openai', 'tutorials/instructor', "tutorials/gradio_integration", diff --git a/enterprise/enterprise_hooks/aporia_ai.py b/enterprise/enterprise_hooks/aporia_ai.py new file mode 100644 index 
0000000000..af909a8b51 --- /dev/null +++ b/enterprise/enterprise_hooks/aporia_ai.py @@ -0,0 +1,208 @@ +# +-------------------------------------------------------------+ +# +# Use AporiaAI for your LLM calls +# +# +-------------------------------------------------------------+ +# Thank you users! We ❤️ you! - Krrish & Ishaan + +import sys +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from typing import Optional, Literal, Union, Any +import litellm, traceback, sys, uuid +from litellm.caching import DualCache +from litellm.proxy._types import UserAPIKeyAuth +from litellm.integrations.custom_guardrail import CustomGuardrail +from fastapi import HTTPException +from litellm._logging import verbose_proxy_logger +from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata +from litellm.litellm_core_utils.logging_utils import ( + convert_litellm_response_object_to_str, +) +from typing import List +from datetime import datetime +import aiohttp, asyncio +from litellm._logging import verbose_proxy_logger +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +import httpx +import json +from litellm.types.guardrails import GuardrailEventHooks + +litellm.set_verbose = True + +GUARDRAIL_NAME = "aporia" + + +class _ENTERPRISE_Aporia(CustomGuardrail): + def __init__( + self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs + ): + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) + ) + self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"] + self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"] + self.event_hook: GuardrailEventHooks + + super().__init__(**kwargs) + + #### CALL HOOKS - proxy only #### + def transform_messages(self, messages: List[dict]) -> List[dict]: + supported_openai_roles = ["system", "user", "assistant"] + default_role = "other" # for unsupported roles - e.g. 
tool + new_messages = [] + for m in messages: + if m.get("role", "") in supported_openai_roles: + new_messages.append(m) + else: + new_messages.append( + { + "role": default_role, + **{key: value for key, value in m.items() if key != "role"}, + } + ) + + return new_messages + + async def prepare_aporia_request( + self, new_messages: List[dict], response_string: Optional[str] = None + ) -> dict: + data: dict[str, Any] = {} + if new_messages is not None: + data["messages"] = new_messages + if response_string is not None: + data["response"] = response_string + + # Set validation target + if new_messages and response_string: + data["validation_target"] = "both" + elif new_messages: + data["validation_target"] = "prompt" + elif response_string: + data["validation_target"] = "response" + + verbose_proxy_logger.debug("Aporia AI request: %s", data) + return data + + async def make_aporia_api_request( + self, new_messages: List[dict], response_string: Optional[str] = None + ): + data = await self.prepare_aporia_request( + new_messages=new_messages, response_string=response_string + ) + + _json_data = json.dumps(data) + + """ + export APORIO_API_KEY= + curl https://gr-prd-trial.aporia.com/some-id \ + -X POST \ + -H "X-APORIA-API-KEY: $APORIO_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + { + "role": "user", + "content": "This is a test prompt" + } + ], + } +' + """ + + response = await self.async_handler.post( + url=self.aporia_api_base + "/validate", + data=_json_data, + headers={ + "X-APORIA-API-KEY": self.aporia_api_key, + "Content-Type": "application/json", + }, + ) + verbose_proxy_logger.debug("Aporia AI response: %s", response.text) + if response.status_code == 200: + # check if the response was flagged + _json_response = response.json() + action: str = _json_response.get( + "action" + ) # possible values are modify, passthrough, block, rephrase + if action == "block": + raise HTTPException( + status_code=400, + detail={ + "error": "Violated guardrail policy", + "aporia_ai_response": _json_response, + }, + ) + + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, + ) + from litellm.types.guardrails import GuardrailEventHooks + + """ + Use this for the post call moderation with Guardrails + """ + event_type: GuardrailEventHooks = GuardrailEventHooks.post_call + if self.should_run_guardrail(data=data, event_type=event_type) is not True: + return + + response_str: Optional[str] = convert_litellm_response_object_to_str(response) + if response_str is not None: + await self.make_aporia_api_request( + response_string=response_str, new_messages=data.get("messages", []) + ) + + add_guardrail_to_applied_guardrails_header( + request_data=data, guardrail_name=self.guardrail_name + ) + + pass + + async def async_moderation_hook( ### 👈 KEY CHANGE ### + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal["completion", "embeddings", "image_generation"], + ): + from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, + ) + from litellm.types.guardrails import GuardrailEventHooks + + event_type: GuardrailEventHooks = GuardrailEventHooks.during_call + if self.should_run_guardrail(data=data, event_type=event_type) is not True: + return + + # old implementation - backwards compatibility + if ( + await should_proceed_based_on_metadata( + data=data, + 
guardrail_name=GUARDRAIL_NAME, + ) + is False + ): + return + + new_messages: Optional[List[dict]] = None + if "messages" in data and isinstance(data["messages"], list): + new_messages = self.transform_messages(messages=data["messages"]) + + if new_messages is not None: + await self.make_aporia_api_request(new_messages=new_messages) + add_guardrail_to_applied_guardrails_header( + request_data=data, guardrail_name=self.guardrail_name + ) + else: + verbose_proxy_logger.warning( + "Aporia AI: not running guardrail. No messages in data" + ) + pass diff --git a/enterprise/enterprise_hooks/aporio_ai.py b/enterprise/enterprise_hooks/aporio_ai.py deleted file mode 100644 index ce8de6eca0..0000000000 --- a/enterprise/enterprise_hooks/aporio_ai.py +++ /dev/null @@ -1,124 +0,0 @@ -# +-------------------------------------------------------------+ -# -# Use AporioAI for your LLM calls -# -# +-------------------------------------------------------------+ -# Thank you users! We ❤️ you! - Krrish & Ishaan - -import sys, os - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -from typing import Optional, Literal, Union -import litellm, traceback, sys, uuid -from litellm.caching import DualCache -from litellm.proxy._types import UserAPIKeyAuth -from litellm.integrations.custom_logger import CustomLogger -from fastapi import HTTPException -from litellm._logging import verbose_proxy_logger -from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata -from typing import List -from datetime import datetime -import aiohttp, asyncio -from litellm._logging import verbose_proxy_logger -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler -import httpx -import json - -litellm.set_verbose = True - -GUARDRAIL_NAME = "aporio" - - -class _ENTERPRISE_Aporio(CustomLogger): - def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None): - self.async_handler = AsyncHTTPHandler( - timeout=httpx.Timeout(timeout=600.0, connect=5.0) - ) - self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"] - self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"] - - #### CALL HOOKS - proxy only #### - def transform_messages(self, messages: List[dict]) -> List[dict]: - supported_openai_roles = ["system", "user", "assistant"] - default_role = "other" # for unsupported roles - e.g. 
tool - new_messages = [] - for m in messages: - if m.get("role", "") in supported_openai_roles: - new_messages.append(m) - else: - new_messages.append( - { - "role": default_role, - **{key: value for key, value in m.items() if key != "role"}, - } - ) - - return new_messages - - async def async_moderation_hook( ### 👈 KEY CHANGE ### - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - call_type: Literal["completion", "embeddings", "image_generation"], - ): - - if ( - await should_proceed_based_on_metadata( - data=data, - guardrail_name=GUARDRAIL_NAME, - ) - is False - ): - return - - new_messages: Optional[List[dict]] = None - if "messages" in data and isinstance(data["messages"], list): - new_messages = self.transform_messages(messages=data["messages"]) - - if new_messages is not None: - data = {"messages": new_messages, "validation_target": "prompt"} - - _json_data = json.dumps(data) - - """ - export APORIO_API_KEY= - curl https://gr-prd-trial.aporia.com/some-id \ - -X POST \ - -H "X-APORIA-API-KEY: $APORIO_API_KEY" \ - -H "Content-Type: application/json" \ - -d '{ - "messages": [ - { - "role": "user", - "content": "This is a test prompt" - } - ], - } -' - """ - - response = await self.async_handler.post( - url=self.aporio_api_base + "/validate", - data=_json_data, - headers={ - "X-APORIA-API-KEY": self.aporio_api_key, - "Content-Type": "application/json", - }, - ) - verbose_proxy_logger.debug("Aporio AI response: %s", response.text) - if response.status_code == 200: - # check if the response was flagged - _json_response = response.json() - action: str = _json_response.get( - "action" - ) # possible values are modify, passthrough, block, rephrase - if action == "block": - raise HTTPException( - status_code=400, - detail={ - "error": "Violated guardrail policy", - "aporio_ai_response": _json_response, - }, - ) diff --git a/enterprise/enterprise_hooks/banned_keywords.py b/enterprise/enterprise_hooks/banned_keywords.py index 4d6545eb07..e282ee5ab4 100644 --- a/enterprise/enterprise_hooks/banned_keywords.py +++ b/enterprise/enterprise_hooks/banned_keywords.py @@ -90,6 +90,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger): async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py new file mode 100644 index 0000000000..a3ac2ea863 --- /dev/null +++ b/litellm/integrations/custom_guardrail.py @@ -0,0 +1,32 @@ +from typing import Literal + +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.guardrails import GuardrailEventHooks + + +class CustomGuardrail(CustomLogger): + + def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs): + self.guardrail_name = guardrail_name + self.event_hook: GuardrailEventHooks = event_hook + super().__init__(**kwargs) + + def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool: + verbose_logger.debug( + "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s", + self.guardrail_name, + event_type, + self.event_hook, + ) + + metadata = data.get("metadata") or {} + requested_guardrails = metadata.get("guardrails") or [] + + if self.guardrail_name not in requested_guardrails: + return False + + if self.event_hook != event_type: + return False + + return True diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 
98b0da25c5..47d28ab56a 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -122,6 +122,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/litellm/litellm_core_utils/logging_utils.py b/litellm/litellm_core_utils/logging_utils.py index fdc9672a00..7fa1be9d8b 100644 --- a/litellm/litellm_core_utils/logging_utils.py +++ b/litellm/litellm_core_utils/logging_utils.py @@ -1,4 +1,12 @@ -from typing import Any +from typing import TYPE_CHECKING, Any, Optional, Union + +if TYPE_CHECKING: + from litellm import ModelResponse as _ModelResponse + + LiteLLMModelResponse = _ModelResponse +else: + LiteLLMModelResponse = Any + import litellm @@ -20,3 +28,21 @@ def convert_litellm_response_object_to_dict(response_obj: Any) -> dict: # If it's not a LiteLLM type, return the object as is return dict(response_obj) + + +def convert_litellm_response_object_to_str( + response_obj: Union[Any, LiteLLMModelResponse] +) -> Optional[str]: + """ + Get the string of the response object from LiteLLM + + """ + if isinstance(response_obj, litellm.ModelResponse): + response_str = "" + for choice in response_obj.choices: + if isinstance(choice, litellm.Choices): + if choice.message.content and isinstance(choice.message.content, str): + response_str += choice.message.content + return response_str + + return None diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index 6b000b148d..26aa28d62a 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -118,17 +118,19 @@ def initialize_callbacks_on_proxy( **init_params ) imported_list.append(lakera_moderations_object) - elif isinstance(callback, str) and callback == "aporio_prompt_injection": - from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio + elif isinstance(callback, str) and callback == "aporia_prompt_injection": + from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import ( + _ENTERPRISE_Aporia, + ) if premium_user is not True: raise Exception( - "Trying to use Aporio AI Guardrail" + "Trying to use Aporia AI Guardrail" + CommonProxyErrors.not_premium_user.value ) - aporio_guardrail_object = _ENTERPRISE_Aporio() - imported_list.append(aporio_guardrail_object) + aporia_guardrail_object = _ENTERPRISE_Aporia() + imported_list.append(aporia_guardrail_object) elif isinstance(callback, str) and callback == "google_text_moderation": from enterprise.enterprise_hooks.google_text_moderation import ( _ENTERPRISE_GoogleTextModeration, @@ -295,3 +297,21 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, headers[f"x-litellm-key-remaining-tokens-{model_group}"] = remaining_tokens return headers + + +def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]: + _metadata = request_data.get("metadata", None) or {} + if "applied_guardrails" in _metadata: + return { + "x-litellm-applied-guardrails": ",".join(_metadata["applied_guardrails"]), + } + + return None + + +def add_guardrail_to_applied_guardrails_header(request_data: Dict, guardrail_name: str): + _metadata = request_data.get("metadata", None) or {} + if "applied_guardrails" in _metadata: + _metadata["applied_guardrails"].append(guardrail_name) + else: + _metadata["applied_guardrails"] = [guardrail_name] diff --git a/litellm/proxy/custom_callbacks1.py 
b/litellm/proxy/custom_callbacks1.py index 37e4a6cdb3..05028f033c 100644 --- a/litellm/proxy/custom_callbacks1.py +++ b/litellm/proxy/custom_callbacks1.py @@ -40,6 +40,7 @@ class MyCustomHandler( async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/litellm/proxy/example_config_yaml/otel_test_config.yaml b/litellm/proxy/example_config_yaml/otel_test_config.yaml index 2e25374433..496ae1710d 100644 --- a/litellm/proxy/example_config_yaml/otel_test_config.yaml +++ b/litellm/proxy/example_config_yaml/otel_test_config.yaml @@ -9,3 +9,16 @@ litellm_settings: cache: true callbacks: ["otel"] +guardrails: + - guardrail_name: "aporia-pre-guard" + litellm_params: + guardrail: aporia # supported values: "aporia", "bedrock", "lakera" + mode: "post_call" + api_key: os.environ/APORIA_API_KEY_1 + api_base: os.environ/APORIA_API_BASE_1 + - guardrail_name: "aporia-post-guard" + litellm_params: + guardrail: aporia # supported values: "aporia", "bedrock", "lakera" + mode: "post_call" + api_key: os.environ/APORIA_API_KEY_2 + api_base: os.environ/APORIA_API_BASE_2 \ No newline at end of file diff --git a/litellm/proxy/guardrails/guardrail_helpers.py b/litellm/proxy/guardrails/guardrail_helpers.py index e0a5f1eb3d..a57b965c8e 100644 --- a/litellm/proxy/guardrails/guardrail_helpers.py +++ b/litellm/proxy/guardrails/guardrail_helpers.py @@ -37,32 +37,35 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b requested_callback_names = [] - # get guardrail configs from `init_guardrails.py` - # for all requested guardrails -> get their associated callbacks - for _guardrail_name, should_run in request_guardrails.items(): - if should_run is False: - verbose_proxy_logger.debug( - "Guardrail %s skipped because request set to False", - _guardrail_name, - ) - continue + # v1 implementation of this + if isinstance(request_guardrails, dict): - # lookup the guardrail in guardrail_name_config_map - guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[ - _guardrail_name - ] + # get guardrail configs from `init_guardrails.py` + # for all requested guardrails -> get their associated callbacks + for _guardrail_name, should_run in request_guardrails.items(): + if should_run is False: + verbose_proxy_logger.debug( + "Guardrail %s skipped because request set to False", + _guardrail_name, + ) + continue - guardrail_callbacks = guardrail_item.callbacks - requested_callback_names.extend(guardrail_callbacks) + # lookup the guardrail in guardrail_name_config_map + guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[ + _guardrail_name + ] - verbose_proxy_logger.debug( - "requested_callback_names %s", requested_callback_names - ) - if guardrail_name in requested_callback_names: - return True + guardrail_callbacks = guardrail_item.callbacks + requested_callback_names.extend(guardrail_callbacks) - # Do no proceeed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } } - return False + verbose_proxy_logger.debug( + "requested_callback_names %s", requested_callback_names + ) + if guardrail_name in requested_callback_names: + return True + + # Do no proceeed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } } + return False return True diff --git a/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py new file mode 100644 index 0000000000..29566d94db --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py @@ -0,0 
+1,212 @@ +# +-------------------------------------------------------------+ +# +# Use AporiaAI for your LLM calls +# +# +-------------------------------------------------------------+ +# Thank you users! We ❤️ you! - Krrish & Ishaan + +import os +import sys + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import asyncio +import json +import sys +import traceback +import uuid +from datetime import datetime +from typing import Any, List, Literal, Optional, Union + +import aiohttp +import httpx +from fastapi import HTTPException + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching import DualCache +from litellm.integrations.custom_guardrail import CustomGuardrail +from litellm.litellm_core_utils.logging_utils import ( + convert_litellm_response_object_to_str, +) +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata +from litellm.types.guardrails import GuardrailEventHooks + +litellm.set_verbose = True + +GUARDRAIL_NAME = "aporia" + + +class _ENTERPRISE_Aporia(CustomGuardrail): + def __init__( + self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs + ): + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) + ) + self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"] + self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"] + self.event_hook: GuardrailEventHooks + + super().__init__(**kwargs) + + #### CALL HOOKS - proxy only #### + def transform_messages(self, messages: List[dict]) -> List[dict]: + supported_openai_roles = ["system", "user", "assistant"] + default_role = "other" # for unsupported roles - e.g. 
tool + new_messages = [] + for m in messages: + if m.get("role", "") in supported_openai_roles: + new_messages.append(m) + else: + new_messages.append( + { + "role": default_role, + **{key: value for key, value in m.items() if key != "role"}, + } + ) + + return new_messages + + async def prepare_aporia_request( + self, new_messages: List[dict], response_string: Optional[str] = None + ) -> dict: + data: dict[str, Any] = {} + if new_messages is not None: + data["messages"] = new_messages + if response_string is not None: + data["response"] = response_string + + # Set validation target + if new_messages and response_string: + data["validation_target"] = "both" + elif new_messages: + data["validation_target"] = "prompt" + elif response_string: + data["validation_target"] = "response" + + verbose_proxy_logger.debug("Aporia AI request: %s", data) + return data + + async def make_aporia_api_request( + self, new_messages: List[dict], response_string: Optional[str] = None + ): + data = await self.prepare_aporia_request( + new_messages=new_messages, response_string=response_string + ) + + _json_data = json.dumps(data) + + """ + export APORIO_API_KEY= + curl https://gr-prd-trial.aporia.com/some-id \ + -X POST \ + -H "X-APORIA-API-KEY: $APORIO_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + { + "role": "user", + "content": "This is a test prompt" + } + ], + } +' + """ + + response = await self.async_handler.post( + url=self.aporia_api_base + "/validate", + data=_json_data, + headers={ + "X-APORIA-API-KEY": self.aporia_api_key, + "Content-Type": "application/json", + }, + ) + verbose_proxy_logger.debug("Aporia AI response: %s", response.text) + if response.status_code == 200: + # check if the response was flagged + _json_response = response.json() + action: str = _json_response.get( + "action" + ) # possible values are modify, passthrough, block, rephrase + if action == "block": + raise HTTPException( + status_code=400, + detail={ + "error": "Violated guardrail policy", + "aporia_ai_response": _json_response, + }, + ) + + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, + ) + from litellm.types.guardrails import GuardrailEventHooks + + """ + Use this for the post call moderation with Guardrails + """ + event_type: GuardrailEventHooks = GuardrailEventHooks.post_call + if self.should_run_guardrail(data=data, event_type=event_type) is not True: + return + + response_str: Optional[str] = convert_litellm_response_object_to_str(response) + if response_str is not None: + await self.make_aporia_api_request( + response_string=response_str, new_messages=data.get("messages", []) + ) + + add_guardrail_to_applied_guardrails_header( + request_data=data, guardrail_name=self.guardrail_name + ) + + pass + + async def async_moderation_hook( ### 👈 KEY CHANGE ### + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal["completion", "embeddings", "image_generation"], + ): + from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, + ) + from litellm.types.guardrails import GuardrailEventHooks + + event_type: GuardrailEventHooks = GuardrailEventHooks.during_call + if self.should_run_guardrail(data=data, event_type=event_type) is not True: + return + + # old implementation - backwards compatibility + if ( + await should_proceed_based_on_metadata( + data=data, + 
guardrail_name=GUARDRAIL_NAME, + ) + is False + ): + return + + new_messages: Optional[List[dict]] = None + if "messages" in data and isinstance(data["messages"], list): + new_messages = self.transform_messages(messages=data["messages"]) + + if new_messages is not None: + await self.make_aporia_api_request(new_messages=new_messages) + add_guardrail_to_applied_guardrails_header( + request_data=data, guardrail_name=self.guardrail_name + ) + else: + verbose_proxy_logger.warning( + "Aporia AI: not running guardrail. No messages in data" + ) + pass diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index 8bf476311f..f5ed9fee84 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -1,12 +1,20 @@ import traceback -from typing import Dict, List +from typing import Dict, List, Literal from pydantic import BaseModel, RootModel import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy -from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec + +# v2 implementation +from litellm.types.guardrails import ( + Guardrail, + GuardrailItem, + GuardrailItemSpec, + LitellmParams, + guardrailConfig, +) all_guardrails: List[GuardrailItem] = [] @@ -66,3 +74,68 @@ def initialize_guardrails( "error initializing guardrails {}".format(str(e)) ) raise e + + +""" +Map guardrail_name: , , during_call + +""" + + +def init_guardrails_v2(all_guardrails: dict): + # Convert the loaded data to the TypedDict structure + guardrail_list = [] + + # Parse each guardrail and replace environment variables + for guardrail in all_guardrails: + + # Init litellm params for guardrail + litellm_params_data = guardrail["litellm_params"] + verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data) + litellm_params = LitellmParams( + guardrail=litellm_params_data["guardrail"], + mode=litellm_params_data["mode"], + api_key=litellm_params_data["api_key"], + api_base=litellm_params_data["api_base"], + ) + + if litellm_params["api_key"]: + if litellm_params["api_key"].startswith("os.environ/"): + litellm_params["api_key"] = litellm.get_secret( + litellm_params["api_key"] + ) + + if litellm_params["api_base"]: + if litellm_params["api_base"].startswith("os.environ/"): + litellm_params["api_base"] = litellm.get_secret( + litellm_params["api_base"] + ) + + # Init guardrail CustomLoggerClass + if litellm_params["guardrail"] == "aporia": + from guardrail_hooks.aporia_ai import _ENTERPRISE_Aporia + + _aporia_callback = _ENTERPRISE_Aporia( + api_base=litellm_params["api_base"], + api_key=litellm_params["api_key"], + guardrail_name=guardrail["guardrail_name"], + event_hook=litellm_params["mode"], + ) + litellm.callbacks.append(_aporia_callback) # type: ignore + elif litellm_params["guardrail"] == "lakera": + from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import ( + _ENTERPRISE_lakeraAI_Moderation, + ) + + _lakera_callback = _ENTERPRISE_lakeraAI_Moderation() + litellm.callbacks.append(_lakera_callback) # type: ignore + + parsed_guardrail = Guardrail( + guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params + ) + + guardrail_list.append(parsed_guardrail) + guardrail_name = guardrail["guardrail_name"] + + # pretty print guardrail_list in green + print(f"\nGuardrail List:{guardrail_list}\n") # noqa diff --git a/litellm/proxy/hooks/azure_content_safety.py b/litellm/proxy/hooks/azure_content_safety.py index 
972ac99928..ccadafaf2e 100644 --- a/litellm/proxy/hooks/azure_content_safety.py +++ b/litellm/proxy/hooks/azure_content_safety.py @@ -1,11 +1,16 @@ -from litellm.integrations.custom_logger import CustomLogger -from litellm.caching import DualCache -from litellm.proxy._types import UserAPIKeyAuth -import litellm, traceback, sys, uuid -from fastapi import HTTPException -from litellm._logging import verbose_proxy_logger +import sys +import traceback +import uuid from typing import Optional +from fastapi import HTTPException + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + class _PROXY_AzureContentSafety( CustomLogger @@ -15,12 +20,12 @@ class _PROXY_AzureContentSafety( def __init__(self, endpoint, api_key, thresholds=None): try: from azure.ai.contentsafety.aio import ContentSafetyClient - from azure.core.credentials import AzureKeyCredential from azure.ai.contentsafety.models import ( - TextCategory, AnalyzeTextOptions, AnalyzeTextOutputType, + TextCategory, ) + from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import HttpResponseError except Exception as e: raise Exception( @@ -132,6 +137,7 @@ class _PROXY_AzureContentSafety( async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response, ): diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py index 4bf08998a4..57985e9a69 100644 --- a/litellm/proxy/hooks/dynamic_rate_limiter.py +++ b/litellm/proxy/hooks/dynamic_rate_limiter.py @@ -254,7 +254,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger): return None async def async_post_call_success_hook( - self, user_api_key_dict: UserAPIKeyAuth, response + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response ): try: if isinstance(response, ModelResponse): @@ -287,7 +287,9 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger): return response return await super().async_post_call_success_hook( - user_api_key_dict, response + data=data, + user_api_key_dict=user_api_key_dict, + response=response, ) except Exception as e: verbose_proxy_logger.exception( diff --git a/litellm/proxy/hooks/presidio_pii_masking.py b/litellm/proxy/hooks/presidio_pii_masking.py index 933d925507..6af7e3d1e5 100644 --- a/litellm/proxy/hooks/presidio_pii_masking.py +++ b/litellm/proxy/hooks/presidio_pii_masking.py @@ -322,6 +322,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger): async def async_post_call_success_hook( self, + data: dict, user_api_key_dict: UserAPIKeyAuth, response: Union[ModelResponse, EmbeddingResponse, ImageResponse], ): diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 85211f943d..78f3e09492 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -316,9 +316,20 @@ async def add_litellm_data_to_request( for k, v in callback_settings_obj.callback_vars.items(): data[k] = v + # Guardrails + move_guardrails_to_metadata( + data=data, _metadata_variable_name=_metadata_variable_name + ) + return data +def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str): + if "guardrails" in data: + data[_metadata_variable_name]["guardrails"] = data["guardrails"] + del data["guardrails"] + + def add_provider_specific_headers_to_request( data: dict, headers: dict, diff --git a/litellm/proxy/proxy_config.yaml 
b/litellm/proxy/proxy_config.yaml index e08be88aad..3c1c64292e 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,50 +1,20 @@ model_list: - - model_name: gpt-4 + - model_name: fake-openai-endpoint litellm_params: model: openai/fake api_key: fake-key api_base: https://exampleopenaiendpoint-production.up.railway.app/ - model_info: - access_groups: ["beta-models"] - - model_name: fireworks-llama-v3-70b-instruct - litellm_params: - model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct - api_key: "os.environ/FIREWORKS" - model_info: - access_groups: ["beta-models"] - - model_name: "*" - litellm_params: - model: "*" - - model_name: "*" - litellm_params: - model: openai/* - api_key: os.environ/OPENAI_API_KEY - - model_name: mistral-small-latest - litellm_params: - model: mistral/mistral-small-latest - api_key: "os.environ/MISTRAL_API_KEY" - - model_name: bedrock-anthropic - litellm_params: - model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0 - - model_name: gemini-1.5-pro-001 - litellm_params: - model: vertex_ai_beta/gemini-1.5-pro-001 - vertex_project: "adroit-crow-413218" - vertex_location: "us-central1" - vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" - # Add path to service account.json -default_vertex_config: - vertex_project: "adroit-crow-413218" - vertex_location: "us-central1" - vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json - - -general_settings: - master_key: sk-1234 - alerting: ["slack"] - -litellm_settings: - fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}] - success_callback: ["langfuse", "prometheus"] - langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"] +guardrails: + - guardrail_name: "aporia-pre-guard" + litellm_params: + guardrail: aporia # supported values: "aporia", "bedrock", "lakera" + mode: "post_call" + api_key: os.environ/APORIA_API_KEY_1 + api_base: os.environ/APORIA_API_BASE_1 + - guardrail_name: "aporia-post-guard" + litellm_params: + guardrail: aporia # supported values: "aporia", "bedrock", "lakera" + mode: "post_call" + api_key: os.environ/APORIA_API_KEY_2 + api_base: os.environ/APORIA_API_BASE_2 \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 3820f0ea33..0fbf10a2b9 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -149,6 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import ( show_missing_vars_in_env, ) from litellm.proxy.common_utils.callback_utils import ( + get_applied_guardrails_header, get_remaining_tokens_and_requests_from_request_data, initialize_callbacks_on_proxy, ) @@ -168,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import ( ) from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config -from litellm.proxy.guardrails.init_guardrails import initialize_guardrails +from litellm.proxy.guardrails.init_guardrails import ( + init_guardrails_v2, + initialize_guardrails, +) from litellm.proxy.health_check import perform_health_check from litellm.proxy.health_endpoints._health_endpoints import router as health_router from litellm.proxy.hooks.prompt_injection_detection import ( @@ -539,6 +543,10 @@ def get_custom_headers( ) headers.update(remaining_tokens_header) + applied_guardrails = get_applied_guardrails_header(request_data) + if applied_guardrails: + 
headers.update(applied_guardrails) + try: return { key: value for key, value in headers.items() if value not in exclude_values @@ -1937,6 +1945,11 @@ class ProxyConfig: async_only_mode=True # only init async clients ), ) # type:ignore + + # Guardrail settings + guardrails_v2 = config.get("guardrails", None) + if guardrails_v2: + init_guardrails_v2(all_guardrails=guardrails_v2) return router, router.get_model_list(), general_settings def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo: @@ -3139,7 +3152,7 @@ async def chat_completion( ### CALL HOOKS ### - modify outgoing data response = await proxy_logging_obj.post_call_success_hook( - user_api_key_dict=user_api_key_dict, response=response + data=data, user_api_key_dict=user_api_key_dict, response=response ) hidden_params = ( @@ -3353,6 +3366,11 @@ async def completion( media_type="text/event-stream", headers=custom_headers, ) + ### CALL HOOKS ### - modify outgoing data + response = await proxy_logging_obj.post_call_success_hook( + data=data, user_api_key_dict=user_api_key_dict, response=response + ) + fastapi_response.headers.update( get_custom_headers( user_api_key_dict=user_api_key_dict, diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index d1d17d0ef5..a2b09b4e69 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -432,12 +432,11 @@ class ProxyLogging: """ Runs the CustomLogger's async_moderation_hook() """ - new_data = safe_deep_copy(data) for callback in litellm.callbacks: try: if isinstance(callback, CustomLogger): await callback.async_moderation_hook( - data=new_data, + data=data, user_api_key_dict=user_api_key_dict, call_type=call_type, ) @@ -717,6 +716,7 @@ class ProxyLogging: async def post_call_success_hook( self, + data: dict, response: Union[ModelResponse, EmbeddingResponse, ImageResponse], user_api_key_dict: UserAPIKeyAuth, ): @@ -738,7 +738,9 @@ class ProxyLogging: _callback = callback # type: ignore if _callback is not None and isinstance(_callback, CustomLogger): await _callback.async_post_call_success_hook( - user_api_key_dict=user_api_key_dict, response=response + user_api_key_dict=user_api_key_dict, + data=data, + response=response, ) except Exception as e: raise e diff --git a/litellm/tests/test_azure_content_safety.py b/litellm/tests/test_azure_content_safety.py index 7b040fb252..dc80c163c3 100644 --- a/litellm/tests/test_azure_content_safety.py +++ b/litellm/tests/test_azure_content_safety.py @@ -1,8 +1,13 @@ # What is this? 
## Unit test for azure content safety -import sys, os, asyncio, time, random -from datetime import datetime +import asyncio +import os +import random +import sys +import time import traceback +from datetime import datetime + from dotenv import load_dotenv from fastapi import HTTPException @@ -13,11 +18,12 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import pytest + import litellm from litellm import Router, mock_completion -from litellm.proxy.utils import ProxyLogging -from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.utils import ProxyLogging @pytest.mark.asyncio @@ -177,7 +183,13 @@ async def test_strict_output_filtering_01(): with pytest.raises(HTTPException) as exc_info: await azure_content_safety.async_post_call_success_hook( - user_api_key_dict=UserAPIKeyAuth(), response=response + user_api_key_dict=UserAPIKeyAuth(), + data={ + "messages": [ + {"role": "system", "content": "You are an helpfull assistant"} + ] + }, + response=response, ) assert exc_info.value.detail["source"] == "output" @@ -216,7 +228,11 @@ async def test_strict_output_filtering_02(): ) await azure_content_safety.async_post_call_success_hook( - user_api_key_dict=UserAPIKeyAuth(), response=response + user_api_key_dict=UserAPIKeyAuth(), + data={ + "messages": [{"role": "system", "content": "You are an helpfull assistant"}] + }, + response=response, ) @@ -251,7 +267,11 @@ async def test_loose_output_filtering_01(): ) await azure_content_safety.async_post_call_success_hook( - user_api_key_dict=UserAPIKeyAuth(), response=response + user_api_key_dict=UserAPIKeyAuth(), + data={ + "messages": [{"role": "system", "content": "You are an helpfull assistant"}] + }, + response=response, ) @@ -286,5 +306,9 @@ async def test_loose_output_filtering_02(): ) await azure_content_safety.async_post_call_success_hook( - user_api_key_dict=UserAPIKeyAuth(), response=response + user_api_key_dict=UserAPIKeyAuth(), + data={ + "messages": [{"role": "system", "content": "You are an helpfull assistant"}] + }, + response=response, ) diff --git a/litellm/tests/test_presidio_masking.py b/litellm/tests/test_presidio_masking.py index 193fcf113b..35a03ea5e3 100644 --- a/litellm/tests/test_presidio_masking.py +++ b/litellm/tests/test_presidio_masking.py @@ -88,7 +88,11 @@ async def test_output_parsing(): mock_response="Hello ! 
How can I assist you today?", ) new_response = await pii_masking.async_post_call_success_hook( - user_api_key_dict=UserAPIKeyAuth(), response=response + user_api_key_dict=UserAPIKeyAuth(), + data={ + "messages": [{"role": "system", "content": "You are an helpfull assistant"}] + }, + response=response, ) assert ( diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index 0296d8de4a..cd9f76f171 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TypedDict from pydantic import BaseModel, ConfigDict from typing_extensions import Required, TypedDict @@ -63,3 +63,26 @@ class GuardrailItem(BaseModel): enabled_roles=enabled_roles, callback_args=callback_args, ) + + +# Define the TypedDicts +class LitellmParams(TypedDict): + guardrail: str + mode: str + api_key: str + api_base: Optional[str] + + +class Guardrail(TypedDict): + guardrail_name: str + litellm_params: LitellmParams + + +class guardrailConfig(TypedDict): + guardrails: List[Guardrail] + + +class GuardrailEventHooks(str, Enum): + pre_call = "pre_call" + post_call = "post_call" + during_call = "during_call" diff --git a/tests/otel_tests/test_guardrails.py b/tests/otel_tests/test_guardrails.py new file mode 100644 index 0000000000..c48a5ba79b --- /dev/null +++ b/tests/otel_tests/test_guardrails.py @@ -0,0 +1,118 @@ +import pytest +import asyncio +import aiohttp, openai +from openai import OpenAI, AsyncOpenAI +from typing import Optional, List, Union +import uuid + + +async def chat_completion( + session, + key, + messages, + model: Union[str, List] = "gpt-4", + guardrails: Optional[List] = None, +): + url = "http://0.0.0.0:4000/chat/completions" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + + data = { + "model": model, + "messages": messages, + "guardrails": [ + "aporia-post-guard", + "aporia-pre-guard", + ], # default guardrails for all tests + } + + if guardrails is not None: + data["guardrails"] = guardrails + + print("data=", data) + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + return response_text + + # response headers + response_headers = response.headers + print("response headers=", response_headers) + + return await response.json(), response_headers + + +@pytest.mark.asyncio +async def test_llm_guard_triggered_safe_request(): + """ + - Tests a request where no content mod is triggered + - Assert that the guardrails applied are returned in the response headers + """ + async with aiohttp.ClientSession() as session: + response, headers = await chat_completion( + session, + "sk-1234", + model="fake-openai-endpoint", + messages=[{"role": "user", "content": f"Hello what's the weather"}], + ) + await asyncio.sleep(3) + + print("response=", response, "response headers", headers) + + assert "x-litellm-applied-guardrails" in headers + + assert ( + headers["x-litellm-applied-guardrails"] + == "aporia-pre-guard,aporia-post-guard" + ) + + +@pytest.mark.asyncio +async def test_llm_guard_triggered(): + """ + - Tests a request where no content mod is triggered + - Assert that the guardrails applied are returned in the response headers + """ + async with aiohttp.ClientSession() as session: + try: + response, headers = await chat_completion( + session, + "sk-1234", + 
model="fake-openai-endpoint", + messages=[ + {"role": "user", "content": f"Hello my name is ishaan@berri.ai"} + ], + ) + pytest.fail("Should have thrown an exception") + except Exception as e: + print(e) + assert "Aporia detected and blocked PII" in str(e) + + +@pytest.mark.asyncio +async def test_no_llm_guard_triggered(): + """ + - Tests a request where no content mod is triggered + - Assert that the guardrails applied are returned in the response headers + """ + async with aiohttp.ClientSession() as session: + response, headers = await chat_completion( + session, + "sk-1234", + model="fake-openai-endpoint", + messages=[{"role": "user", "content": f"Hello what's the weather"}], + guardrails=[], + ) + await asyncio.sleep(3) + + print("response=", response, "response headers", headers) + + assert "x-litellm-applied-guardrails" not in headers