diff --git a/enterprise/enterprise_hooks/prompt_injection_detection.py b/enterprise/enterprise_hooks/prompt_injection_detection.py
new file mode 100644
index 000000000..ebeb19c6e
--- /dev/null
+++ b/enterprise/enterprise_hooks/prompt_injection_detection.py
@@ -0,0 +1,144 @@
+# +------------------------------------+
+#
+#           Prompt Injection Detection
+#
+# +------------------------------------+
+#  Thank you users! We ❤️ you! - Krrish & Ishaan
+## Reject a call if it contains a prompt injection attack.
+
+
+from typing import Optional, Literal
+import litellm
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from litellm._logging import verbose_proxy_logger
+from litellm.utils import get_formatted_prompt
+from fastapi import HTTPException
+import json, traceback, re
+from difflib import SequenceMatcher
+from typing import List
+
+
+class _ENTERPRISE_PromptInjectionDetection(CustomLogger):
+    # Phrase components combined into known prompt-injection keyword phrases
+    def __init__(self):
+        self.verbs = [
+            "Ignore",
+            "Disregard",
+            "Skip",
+            "Forget",
+            "Neglect",
+            "Overlook",
+            "Omit",
+            "Bypass",
+            "Pay no attention to",
+            "Do not follow",
+            "Do not obey",
+        ]
+        self.adjectives = [
+            "",
+            "prior",
+            "previous",
+            "preceding",
+            "above",
+            "foregoing",
+            "earlier",
+            "initial",
+        ]
+        self.prepositions = [
+            "",
+            "and start over",
+            "and start anew",
+            "and begin afresh",
+            "and start from scratch",
+        ]
+
+    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
+        if level == "INFO":
+            verbose_proxy_logger.info(print_statement)
+        elif level == "DEBUG":
+            verbose_proxy_logger.debug(print_statement)
+
+        if litellm.set_verbose is True:
+            print(print_statement)  # noqa
+
+    def generate_injection_keywords(self) -> List[str]:
+        combinations = []
+        for verb in self.verbs:
+            for adj in self.adjectives:
+                for prep in self.prepositions:
+                    phrase = " ".join(filter(None, [verb, adj, prep])).strip()
+                    combinations.append(phrase.lower())
+        return combinations
+
+    def check_user_input_similarity(
+        self, user_input: str, similarity_threshold: float = 0.7
+    ) -> bool:
+        user_input_lower = user_input.lower()
+        keywords = self.generate_injection_keywords()
+
+        for keyword in keywords:
+            # Calculate the length of the keyword to extract substrings of the same length from user input
+            keyword_length = len(keyword)
+
+            for i in range(len(user_input_lower) - keyword_length + 1):
+                # Extract a substring of the same length as the keyword
+                substring = user_input_lower[i : i + keyword_length]
+
+                # Calculate similarity
+                match_ratio = SequenceMatcher(None, substring, keyword).ratio()
+                if match_ratio > similarity_threshold:
+                    self.print_verbose(
+                        print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
+                        level="INFO",
+                    )
+                    return True  # Found a highly similar substring
+        return False  # No substring crossed the threshold
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
+    ):
+        try:
+            """
+            - check the call type is one a prompt can be extracted from
+            - reject the call if the prompt resembles a prompt injection attack
+            """
+            self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
+            try:
+                assert call_type in [
+                    "completion",
+                    "embeddings",
+                    "image_generation",
+                    "moderation",
+                    "audio_transcription",
+                ]
+            except Exception:
+                self.print_verbose(
+                    f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
+                )
+                return data
+            formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore
+
+            is_prompt_attack = self.check_user_input_similarity(
+                user_input=formatted_prompt
+            )
+
+            if is_prompt_attack is True:
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": "Rejected message. This is a prompt injection attack."
+                    },
+                )
+
+            return data
+
+        except HTTPException:
+            raise
+        except Exception:
+            traceback.print_exc()
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index b5b68df0c..3bba37e96 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1665,6 +1665,18 @@ class ProxyConfig:
 
                         banned_keywords_obj = _ENTERPRISE_BannedKeywords()
                         imported_list.append(banned_keywords_obj)
+                    elif (
+                        isinstance(callback, str)
+                        and callback == "detect_prompt_injection"
+                    ):
+                        from litellm.proxy.enterprise.enterprise_hooks.prompt_injection_detection import (
+                            _ENTERPRISE_PromptInjectionDetection,
+                        )
+
+                        prompt_injection_detection_obj = (
+                            _ENTERPRISE_PromptInjectionDetection()
+                        )
+                        imported_list.append(prompt_injection_detection_obj)
                     else:
                         imported_list.append(
                             get_instance_fn(
diff --git a/litellm/utils.py b/litellm/utils.py
index 8c62a2222..3fb961c05 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5301,6 +5301,40 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
         ]
 
+
+def get_formatted_prompt(
+    data: dict,
+    call_type: Literal[
+        "completion",
+        "embedding", "embeddings",  # proxy hooks pass the plural form
+        "image_generation",
+        "audio_transcription",
+        "moderation",
+    ],
+) -> str:
+    """
+    Extracts the prompt from the input data based on the call type.
+
+    Returns a string.
+    """
+    prompt = ""
+    if call_type == "completion":
+        for m in data["messages"]:
+            if "content" in m and isinstance(m["content"], str):
+                prompt += m["content"]
+    elif call_type in ("embedding", "embeddings", "moderation"):  # accept both singular and plural
+        if isinstance(data["input"], str):
+            prompt = data["input"]
+        elif isinstance(data["input"], list):
+            for m in data["input"]:
+                prompt += m
+    elif call_type == "image_generation":
+        prompt = data["prompt"]
+    elif call_type == "audio_transcription":
+        if "prompt" in data:
+            prompt = data["prompt"]
+    return prompt
+
 
 def get_llm_provider(
     model: str,
     custom_llm_provider: Optional[str] = None,