feat(proxy_server.py): enable llm api based prompt injection checks

run user calls through an llm api to check for prompt injection attacks. This happens in parallel to th
e actual llm call using `async_moderation_hook`
This commit is contained in:
Krrish Dholakia 2024-03-20 22:43:42 -07:00
parent f24d3ffdb6
commit d91f9a9f50
11 changed files with 271 additions and 24 deletions

View file

@ -10,10 +10,11 @@
from typing import Optional, Literal
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy._types import UserAPIKeyAuth, LiteLLMPromptInjectionParams
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from litellm.utils import get_formatted_prompt
from litellm.llms.prompt_templates.factory import prompt_injection_detection_default_pt
from fastapi import HTTPException
import json, traceback, re
from difflib import SequenceMatcher
@ -22,7 +23,13 @@ from typing import List
class _OPTIONAL_PromptInjectionDetection(CustomLogger):
# Class variables or attributes
def __init__(self):
def __init__(
self,
prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
):
self.prompt_injection_params = prompt_injection_params
self.llm_router: Optional[litellm.Router] = None
self.verbs = [
"Ignore",
"Disregard",
@ -63,6 +70,30 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
if litellm.set_verbose is True:
print(print_statement) # noqa
def update_environment(self, router: Optional[litellm.Router] = None):
self.llm_router = router
if (
self.prompt_injection_params is not None
and self.prompt_injection_params.llm_api_check == True
):
if self.llm_router is None:
raise Exception(
"PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
)
verbose_proxy_logger.debug(
f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
)
if (
self.prompt_injection_params.llm_api_name is None
or self.prompt_injection_params.llm_api_name
not in self.llm_router.model_names
):
raise Exception(
"PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
)
def generate_injection_keywords(self) -> List[str]:
combinations = []
for verb in self.verbs:
@ -127,9 +158,28 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
return data
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
is_prompt_attack = False
if self.prompt_injection_params is not None:
# 1. check if heuristics check turned on
if self.prompt_injection_params.heuristics_check == True:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
if is_prompt_attack == True:
raise HTTPException(
status_code=400,
detail={
"error": "Rejected message. This is a prompt injection attack."
},
)
# 2. check if vector db similarity check turned on [TODO] Not Implemented yet
if self.prompt_injection_params.vector_db_check == True:
pass
else:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
if is_prompt_attack == True:
raise HTTPException(
@ -145,3 +195,62 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
raise e
except Exception as e:
traceback.print_exc()
async def async_moderation_hook(
self,
data: dict,
call_type: (
Literal["completion"] | Literal["embeddings"] | Literal["image_generation"]
),
):
verbose_proxy_logger.debug(
f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
)
if self.prompt_injection_params is None:
return
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
is_prompt_attack = False
prompt_injection_system_prompt = getattr(
self.prompt_injection_params,
"llm_api_system_prompt",
prompt_injection_detection_default_pt(),
)
# 3. check if llm api check turned on
if (
self.prompt_injection_params.llm_api_check == True
and self.prompt_injection_params.llm_api_name is not None
and self.llm_router is not None
):
# make a call to the llm api
response = await self.llm_router.acompletion(
model=self.prompt_injection_params.llm_api_name,
messages=[
{
"role": "system",
"content": prompt_injection_system_prompt,
},
{"role": "user", "content": formatted_prompt},
],
)
verbose_proxy_logger.debug(f"Received LLM Moderation response: {response}")
if isinstance(response, litellm.ModelResponse) and isinstance(
response.choices, litellm.Choices
):
if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore
is_prompt_attack = True
if is_prompt_attack == True:
raise HTTPException(
status_code=400,
detail={
"error": "Rejected message. This is a prompt injection attack."
},
)
return is_prompt_attack