mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(proxy_server.py): enable llm api based prompt injection checks
run user calls through an llm api to check for prompt injection attacks. This happens in parallel to th e actual llm call using `async_moderation_hook`
This commit is contained in:
parent
f24d3ffdb6
commit
d91f9a9f50
11 changed files with 271 additions and 24 deletions
|
@ -10,10 +10,11 @@
|
|||
from typing import Optional, Literal
|
||||
import litellm
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy._types import UserAPIKeyAuth, LiteLLMPromptInjectionParams
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.utils import get_formatted_prompt
|
||||
from litellm.llms.prompt_templates.factory import prompt_injection_detection_default_pt
|
||||
from fastapi import HTTPException
|
||||
import json, traceback, re
|
||||
from difflib import SequenceMatcher
|
||||
|
@ -22,7 +23,13 @@ from typing import List
|
|||
|
||||
class _OPTIONAL_PromptInjectionDetection(CustomLogger):
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
|
||||
):
|
||||
self.prompt_injection_params = prompt_injection_params
|
||||
self.llm_router: Optional[litellm.Router] = None
|
||||
|
||||
self.verbs = [
|
||||
"Ignore",
|
||||
"Disregard",
|
||||
|
@ -63,6 +70,30 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
|
|||
if litellm.set_verbose is True:
|
||||
print(print_statement) # noqa
|
||||
|
||||
def update_environment(self, router: Optional[litellm.Router] = None):
|
||||
self.llm_router = router
|
||||
|
||||
if (
|
||||
self.prompt_injection_params is not None
|
||||
and self.prompt_injection_params.llm_api_check == True
|
||||
):
|
||||
if self.llm_router is None:
|
||||
raise Exception(
|
||||
"PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
|
||||
)
|
||||
if (
|
||||
self.prompt_injection_params.llm_api_name is None
|
||||
or self.prompt_injection_params.llm_api_name
|
||||
not in self.llm_router.model_names
|
||||
):
|
||||
raise Exception(
|
||||
"PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
|
||||
)
|
||||
|
||||
def generate_injection_keywords(self) -> List[str]:
|
||||
combinations = []
|
||||
for verb in self.verbs:
|
||||
|
@ -127,9 +158,28 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
|
|||
return data
|
||||
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
|
||||
|
||||
is_prompt_attack = self.check_user_input_similarity(
|
||||
user_input=formatted_prompt
|
||||
)
|
||||
is_prompt_attack = False
|
||||
|
||||
if self.prompt_injection_params is not None:
|
||||
# 1. check if heuristics check turned on
|
||||
if self.prompt_injection_params.heuristics_check == True:
|
||||
is_prompt_attack = self.check_user_input_similarity(
|
||||
user_input=formatted_prompt
|
||||
)
|
||||
if is_prompt_attack == True:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Rejected message. This is a prompt injection attack."
|
||||
},
|
||||
)
|
||||
# 2. check if vector db similarity check turned on [TODO] Not Implemented yet
|
||||
if self.prompt_injection_params.vector_db_check == True:
|
||||
pass
|
||||
else:
|
||||
is_prompt_attack = self.check_user_input_similarity(
|
||||
user_input=formatted_prompt
|
||||
)
|
||||
|
||||
if is_prompt_attack == True:
|
||||
raise HTTPException(
|
||||
|
@ -145,3 +195,62 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
|
|||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
call_type: (
|
||||
Literal["completion"] | Literal["embeddings"] | Literal["image_generation"]
|
||||
),
|
||||
):
|
||||
verbose_proxy_logger.debug(
|
||||
f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
|
||||
)
|
||||
|
||||
if self.prompt_injection_params is None:
|
||||
return
|
||||
|
||||
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
|
||||
is_prompt_attack = False
|
||||
|
||||
prompt_injection_system_prompt = getattr(
|
||||
self.prompt_injection_params,
|
||||
"llm_api_system_prompt",
|
||||
prompt_injection_detection_default_pt(),
|
||||
)
|
||||
|
||||
# 3. check if llm api check turned on
|
||||
if (
|
||||
self.prompt_injection_params.llm_api_check == True
|
||||
and self.prompt_injection_params.llm_api_name is not None
|
||||
and self.llm_router is not None
|
||||
):
|
||||
# make a call to the llm api
|
||||
response = await self.llm_router.acompletion(
|
||||
model=self.prompt_injection_params.llm_api_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": prompt_injection_system_prompt,
|
||||
},
|
||||
{"role": "user", "content": formatted_prompt},
|
||||
],
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(f"Received LLM Moderation response: {response}")
|
||||
|
||||
if isinstance(response, litellm.ModelResponse) and isinstance(
|
||||
response.choices, litellm.Choices
|
||||
):
|
||||
if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore
|
||||
is_prompt_attack = True
|
||||
|
||||
if is_prompt_attack == True:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Rejected message. This is a prompt injection attack."
|
||||
},
|
||||
)
|
||||
|
||||
return is_prompt_attack
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue