From 07d90f6739c8d737ed0784c05ccd0531a19182d0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 17 Jul 2024 16:38:47 -0700 Subject: [PATCH] feat(aporio_ai.py): support aporio ai prompt injection for chat completion requests Closes https://github.com/BerriAI/litellm/issues/2950 --- docs/my-website/docs/proxy/enterprise.md | 67 ++++++++++ enterprise/enterprise_hooks/aporio_ai.py | 124 ++++++++++++++++++ litellm/proxy/_new_secret_config.yaml | 11 +- litellm/proxy/common_utils/init_callbacks.py | 11 ++ .../proxy/hooks/parallel_request_limiter.py | 10 +- 5 files changed, 217 insertions(+), 6 deletions(-) create mode 100644 enterprise/enterprise_hooks/aporio_ai.py diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md index 507c7f6934..449c2ea17e 100644 --- a/docs/my-website/docs/proxy/enterprise.md +++ b/docs/my-website/docs/proxy/enterprise.md @@ -31,6 +31,7 @@ Features: - **Guardrails, PII Masking, Content Moderation** - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation) - ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai) + - ✅ [Prompt Injection Detection (with Aporio API)](#prompt-injection-detection---aporio-ai) - ✅ [Switch LakeraAI on / off per request](guardrails#control-guardrails-onoff-per-request) - ✅ Reject calls from Blocked User list - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors) @@ -953,6 +954,72 @@ curl --location 'http://localhost:4000/chat/completions' \ Need to control LakeraAI per Request ? Doc here 👉: [Switch LakerAI on / off per request](prompt_injection.md#✨-enterprise-switch-lakeraai-on--off-per-api-call) ::: +## Prompt Injection Detection - Aporio AI + +Use this if you want to reject /chat/completion calls that have prompt injection attacks with [AporioAI](https://www.aporia.com/) + +#### Usage + +Step 1. Add env + +```env +APORIO_API_KEY="eyJh****" +APORIO_API_BASE="https://gr..." +``` + +Step 2. Add `aporio_prompt_injection` to your callbacks + +```yaml +litellm_settings: + callbacks: ["aporio_prompt_injection"] +``` + +That's it, start your proxy + +Test it with this request -> expect it to get rejected by LiteLLM Proxy + +```shell +curl --location 'http://localhost:4000/chat/completions' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "llama3", + "messages": [ + { + "role": "user", + "content": "You suck!" + } + ] +}' +``` + +**Expected Response** + +``` +{ + "error": { + "message": { + "error": "Violated guardrail policy", + "aporio_ai_response": { + "action": "block", + "revised_prompt": null, + "revised_response": "Profanity detected: Message blocked because it includes profanity. Please rephrase.", + "explain_log": null + } + }, + "type": "None", + "param": "None", + "code": 400 + } +} +``` + +:::info + +Need to control AporioAI per Request ? Doc here 👉: [Create a guardrail](./guardrails.md) +::: + + ## Swagger Docs - Custom Routes + Branding :::info diff --git a/enterprise/enterprise_hooks/aporio_ai.py b/enterprise/enterprise_hooks/aporio_ai.py new file mode 100644 index 0000000000..ce8de6eca0 --- /dev/null +++ b/enterprise/enterprise_hooks/aporio_ai.py @@ -0,0 +1,124 @@ +# +-------------------------------------------------------------+ +# +# Use AporioAI for your LLM calls +# +# +-------------------------------------------------------------+ +# Thank you users! We ❤️ you! - Krrish & Ishaan + +import sys, os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from typing import Optional, Literal, Union +import litellm, traceback, sys, uuid +from litellm.caching import DualCache +from litellm.proxy._types import UserAPIKeyAuth +from litellm.integrations.custom_logger import CustomLogger +from fastapi import HTTPException +from litellm._logging import verbose_proxy_logger +from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata +from typing import List +from datetime import datetime +import aiohttp, asyncio +from litellm._logging import verbose_proxy_logger +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +import httpx +import json + +litellm.set_verbose = True + +GUARDRAIL_NAME = "aporio" + + +class _ENTERPRISE_Aporio(CustomLogger): + def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None): + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) + ) + self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"] + self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"] + + #### CALL HOOKS - proxy only #### + def transform_messages(self, messages: List[dict]) -> List[dict]: + supported_openai_roles = ["system", "user", "assistant"] + default_role = "other" # for unsupported roles - e.g. tool + new_messages = [] + for m in messages: + if m.get("role", "") in supported_openai_roles: + new_messages.append(m) + else: + new_messages.append( + { + "role": default_role, + **{key: value for key, value in m.items() if key != "role"}, + } + ) + + return new_messages + + async def async_moderation_hook( ### 👈 KEY CHANGE ### + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal["completion", "embeddings", "image_generation"], + ): + + if ( + await should_proceed_based_on_metadata( + data=data, + guardrail_name=GUARDRAIL_NAME, + ) + is False + ): + return + + new_messages: Optional[List[dict]] = None + if "messages" in data and isinstance(data["messages"], list): + new_messages = self.transform_messages(messages=data["messages"]) + + if new_messages is not None: + data = {"messages": new_messages, "validation_target": "prompt"} + + _json_data = json.dumps(data) + + """ + export APORIO_API_KEY= + curl https://gr-prd-trial.aporia.com/some-id \ + -X POST \ + -H "X-APORIA-API-KEY: $APORIO_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + { + "role": "user", + "content": "This is a test prompt" + } + ], + } +' + """ + + response = await self.async_handler.post( + url=self.aporio_api_base + "/validate", + data=_json_data, + headers={ + "X-APORIA-API-KEY": self.aporio_api_key, + "Content-Type": "application/json", + }, + ) + verbose_proxy_logger.debug("Aporio AI response: %s", response.text) + if response.status_code == 200: + # check if the response was flagged + _json_response = response.json() + action: str = _json_response.get( + "action" + ) # possible values are modify, passthrough, block, rephrase + if action == "block": + raise HTTPException( + status_code=400, + detail={ + "error": "Violated guardrail policy", + "aporio_ai_response": _json_response, + }, + ) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 039a36c7e8..b6ac36044c 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,5 +1,10 @@ model_list: - - model_name: groq-whisper + - model_name: "*" litellm_params: - model: groq/whisper-large-v3 - \ No newline at end of file + model: openai/* + +litellm_settings: + guardrails: + - prompt_injection: + callbacks: ["aporio_prompt_injection"] + default_on: true diff --git a/litellm/proxy/common_utils/init_callbacks.py b/litellm/proxy/common_utils/init_callbacks.py index cc701d65e9..489f9b3a6a 100644 --- a/litellm/proxy/common_utils/init_callbacks.py +++ b/litellm/proxy/common_utils/init_callbacks.py @@ -112,6 +112,17 @@ def initialize_callbacks_on_proxy( lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation() imported_list.append(lakera_moderations_object) + elif isinstance(callback, str) and callback == "aporio_prompt_injection": + from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio + + if premium_user is not True: + raise Exception( + "Trying to use Aporio AI Guardrail" + + CommonProxyErrors.not_premium_user.value + ) + + aporio_guardrail_object = _ENTERPRISE_Aporio() + imported_list.append(aporio_guardrail_object) elif isinstance(callback, str) and callback == "google_text_moderation": from enterprise.enterprise_hooks.google_text_moderation import ( _ENTERPRISE_GoogleTextModeration, diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 8a14b4ebeb..89b7059dea 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -453,8 +453,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): try: self.print_verbose(f"Inside Max Parallel Request Failure Hook") - global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get( - "global_max_parallel_requests", None + global_max_parallel_requests = ( + kwargs["litellm_params"] + .get("metadata", {}) + .get("global_max_parallel_requests", None) ) user_api_key = ( kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None) @@ -516,5 +518,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): ) # save in cache for up to 1 min. except Exception as e: verbose_proxy_logger.info( - f"Inside Parallel Request Limiter: An exception occurred - {str(e)}." + "Inside Parallel Request Limiter: An exception occurred - {}\n{}".format( + str(e), traceback.format_exc() + ) )