forked from phoenix/litellm-mirror
Merge pull request #2498 from BerriAI/litellm_prompt_injection_detection
feat(prompt_injection_detection.py): support simple heuristic similarity check for prompt injection attacks
This commit is contained in:
commit
3e66b50602
3 changed files with 190 additions and 0 deletions
144
enterprise/enterprise_hooks/prompt_injection_detection.py
Normal file
144
enterprise/enterprise_hooks/prompt_injection_detection.py
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
# +------------------------------------+
|
||||||
|
#
|
||||||
|
# Prompt Injection Detection
|
||||||
|
#
|
||||||
|
# +------------------------------------+
|
||||||
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
## Reject a call if it contains a prompt injection attack.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional, Literal
|
||||||
|
import litellm
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.utils import get_formatted_prompt
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import json, traceback, re
|
||||||
|
from difflib import SequenceMatcher
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class _ENTERPRISE_PromptInjectionDetection(CustomLogger):
    """
    Heuristic prompt-injection guard for the litellm proxy.

    Builds a catalogue of known injection phrases from verb / adjective /
    preposition combinations (e.g. "ignore previous instructions and start
    over") and rejects any request whose formatted prompt contains a
    substring that is highly similar to one of them.
    """

    def __init__(self):
        # Action words commonly used to override earlier instructions.
        self.verbs = [
            "Ignore",
            "Disregard",
            "Skip",
            "Forget",
            "Neglect",
            "Overlook",
            "Omit",
            "Bypass",
            "Pay no attention to",
            "Do not follow",
            "Do not obey",
        ]
        # Qualifiers for "which instructions" ("" = verb used alone).
        self.adjectives = [
            "",
            "prior",
            "previous",
            "preceding",
            "above",
            "foregoing",
            "earlier",
            "initial",
        ]
        # Trailing clauses attackers append ("" = no trailing clause).
        self.prepositions = [
            "",
            "and start over",
            "and start anew",
            "and begin afresh",
            "and start from scratch",
        ]
        # Precompute the phrase list once; regenerating it on every request
        # (as the original did inside check_user_input_similarity) repeats
        # |verbs| * |adjectives| * |prepositions| work per call.
        self.attack_phrases: List[str] = self.generate_injection_keywords()

    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
        """Log via the proxy logger at *level*; echo to stdout when litellm.set_verbose is on."""
        if level == "INFO":
            verbose_proxy_logger.info(print_statement)
        elif level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)

        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    def generate_injection_keywords(self) -> List[str]:
        """
        Return every lowercase "verb [adjective] [preposition]" combination.

        Empty adjective/preposition entries are filtered out so phrases never
        contain doubled spaces.
        """
        combinations = []
        for verb in self.verbs:
            for adj in self.adjectives:
                for prep in self.prepositions:
                    phrase = " ".join(filter(None, [verb, adj, prep])).strip()
                    combinations.append(phrase.lower())
        return combinations

    def check_user_input_similarity(
        self, user_input: str, similarity_threshold: float = 0.7
    ) -> bool:
        """
        Return True if any window of *user_input* scores above
        *similarity_threshold* (SequenceMatcher ratio) against a known
        injection phrase.
        """
        user_input_lower = user_input.lower()

        for keyword in self.attack_phrases:
            # Slide a window the size of the keyword across the input so a
            # fuzzy match anywhere in the prompt is caught.
            keyword_length = len(keyword)

            for i in range(len(user_input_lower) - keyword_length + 1):
                substring = user_input_lower[i : i + keyword_length]

                match_ratio = SequenceMatcher(None, substring, keyword).ratio()
                if match_ratio > similarity_threshold:
                    self.print_verbose(
                        print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
                        level="INFO",
                    )
                    return True  # Found a highly similar substring
        return False  # No substring crossed the threshold

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        """
        Pre-call hook: raise HTTP 400 if the request's formatted prompt looks
        like a prompt-injection attack; otherwise return *data* unchanged.

        Unsupported call types are passed through untouched.
        """
        try:
            self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")

            # NOTE: the original used `assert call_type in [...]`, which
            # `python -O` strips, silently disabling this guard. A plain
            # membership test keeps the same pass-through behavior.
            if call_type not in [
                "completion",
                "embeddings",
                "image_generation",
                "moderation",
                "audio_transcription",
            ]:
                self.print_verbose(
                    f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
                )
                return data

            formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore

            is_prompt_attack = self.check_user_input_similarity(
                user_input=formatted_prompt
            )

            if is_prompt_attack:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Rejected message. This is a prompt injection attack."
                    },
                )

            return data

        except HTTPException:
            # Deliberate rejection - propagate to the caller.
            raise
        except Exception:
            # Best-effort guard: log the unexpected failure but do not drop
            # the request. The original fell through here and implicitly
            # returned None, losing the request data.
            traceback.print_exc()
            return data
|
|
@ -1665,6 +1665,18 @@ class ProxyConfig:
|
||||||
|
|
||||||
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
|
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
|
||||||
imported_list.append(banned_keywords_obj)
|
imported_list.append(banned_keywords_obj)
|
||||||
|
elif (
|
||||||
|
isinstance(callback, str)
|
||||||
|
and callback == "detect_prompt_injection"
|
||||||
|
):
|
||||||
|
from litellm.proxy.enterprise.enterprise_hooks.prompt_injection_detection import (
|
||||||
|
_ENTERPRISE_PromptInjectionDetection,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_injection_detection_obj = (
|
||||||
|
_ENTERPRISE_PromptInjectionDetection()
|
||||||
|
)
|
||||||
|
imported_list.append(prompt_injection_detection_obj)
|
||||||
else:
|
else:
|
||||||
imported_list.append(
|
imported_list.append(
|
||||||
get_instance_fn(
|
get_instance_fn(
|
||||||
|
|
|
@ -5301,6 +5301,40 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_formatted_prompt(
    data: dict,
    call_type: Literal[
        "completion",
        "embedding",
        "image_generation",
        "audio_transcription",
        "moderation",
    ],
) -> str:
    """
    Extract the prompt text from *data* for the given *call_type*.

    - "completion": concatenation of every string ``content`` field in
      ``data["messages"]`` (non-string contents are skipped).
    - "embedding" / "moderation": ``data["input"]`` itself when it is a
      string, or the concatenation of its elements when it is a list.
    - "image_generation": ``data["prompt"]``.
    - "audio_transcription": ``data["prompt"]`` when present.

    Returns an empty string when nothing matches.
    """
    if call_type == "completion":
        return "".join(
            message["content"]
            for message in data["messages"]
            if isinstance(message.get("content"), str)
        )

    if call_type in ("embedding", "moderation"):
        source = data["input"]
        if isinstance(source, str):
            return source
        if isinstance(source, list):
            return "".join(source)
        return ""

    if call_type == "image_generation":
        return data["prompt"]

    if call_type == "audio_transcription":
        return data.get("prompt", "")

    return ""
|
||||||
|
|
||||||
|
|
||||||
def get_llm_provider(
|
def get_llm_provider(
|
||||||
model: str,
|
model: str,
|
||||||
custom_llm_provider: Optional[str] = None,
|
custom_llm_provider: Optional[str] = None,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue