diff --git a/enterprise/enterprise_hooks/lakera_ai.py b/enterprise/enterprise_hooks/lakera_ai.py
new file mode 100644
index 000000000..bef009dc5
--- /dev/null
+++ b/enterprise/enterprise_hooks/lakera_ai.py
@@ -0,0 +1,121 @@
+# +-------------------------------------------------------------+
+#
+#           Use lakeraAI /moderations for your LLM calls
+#
+# +-------------------------------------------------------------+
+#  Thank you users! We ❤️ you! - Krrish & Ishaan
+
+import sys, os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+from typing import Optional, Literal, Union
+import litellm, traceback, sys, uuid
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger
+from litellm.utils import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    StreamingChoices,
+)
+from datetime import datetime
+import aiohttp, asyncio
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+import httpx
+import json
+
+litellm.set_verbose = True
+
+
+class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
+    """Proxy hook that screens request content via Lakera Guard's prompt-injection API."""
+
+    def __init__(self):
+        self.async_handler = AsyncHTTPHandler(
+            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        )
+        # Fail fast at startup if the key is missing (raises KeyError).
+        self.lakera_api_key = os.environ["LAKERA_API_KEY"]
+
+    #### CALL HOOKS - proxy only ####
+
+    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal["completion", "embeddings", "image_generation"],
+    ):
+        """Raise HTTPException(400) if Lakera Guard flags the request content."""
+        # Concatenate every string message content into one Lakera input.
+        text = ""
+        if "messages" in data and isinstance(data["messages"], list):
+            for m in data["messages"]:  # assume messages is a list
+                if "content" in m and isinstance(m["content"], str):
+                    text += m["content"]
+
+        # https://platform.lakera.ai/account/api-keys
+        # Build a separate payload dict so the caller's `data` is not clobbered.
+        _payload = {"input": text}
+
+        _json_data = json.dumps(_payload)
+
+        """
+        export LAKERA_GUARD_API_KEY=<your key>
+        curl https://api.lakera.ai/v1/prompt_injection \
+            -X POST \
+            -H "Authorization: Bearer $LAKERA_GUARD_API_KEY" \
+            -H "Content-Type: application/json" \
+            -d '{"input": "Your content goes here"}'
+        """
+
+        response = await self.async_handler.post(
+            url="https://api.lakera.ai/v1/prompt_injection",
+            data=_json_data,
+            headers={
+                "Authorization": "Bearer " + self.lakera_api_key,
+                "Content-Type": "application/json",
+            },
+        )
+        verbose_proxy_logger.debug("Lakera AI response: %s", response.text)
+        if response.status_code == 200:
+            # check if the response was flagged
+            """
+            Example Response from Lakera AI
+
+            {
+                "model": "lakera-guard-1",
+                "results": [
+                    {
+                        "categories": {
+                            "prompt_injection": true,
+                            "jailbreak": false
+                        },
+                        "category_scores": {
+                            "prompt_injection": 1.0,
+                            "jailbreak": 0.0
+                        },
+                        "flagged": true,
+                        "payload": {}
+                    }
+                ],
+                "dev_info": {
+                    "git_revision": "784489d3",
+                    "git_timestamp": "2024-05-22T16:51:26+00:00"
+                }
+            }
+            """
+            _json_response = response.json()
+            # "results" is a LIST of result dicts; flag if ANY entry is flagged
+            # (calling .get() on the list itself would raise AttributeError).
+            _results = _json_response.get("results", [])
+            flagged = any(r.get("flagged", False) for r in _results)
+
+            if flagged:
+                raise HTTPException(
+                    status_code=400, detail={"error": "Violated content safety policy"}
+                )
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 4045c7d91..a8c2232bb 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2325,6 +2325,18 @@ class ProxyConfig:
                                 _ENTERPRISE_OpenAI_Moderation()
                             )
                             imported_list.append(openai_moderations_object)
+                        elif (
+                            isinstance(callback, str)
+                            and callback == "lakera_prompt_injection"
+                        ):
+                            from enterprise.enterprise_hooks.lakera_ai import (
+                                _ENTERPRISE_lakeraAI_Moderation,
+                            )
+
+                            lakera_moderations_object = (
+                                _ENTERPRISE_lakeraAI_Moderation()
+                            )
+                            imported_list.append(lakera_moderations_object)
                         elif (
                             isinstance(callback, str)
                             and callback == "google_text_moderation"