feat(proxy_server.py): enable llm api based prompt injection checks

Run user calls through an LLM API to check for prompt injection attacks. This happens in parallel to the actual LLM call, using `async_moderation_hook`.
Krrish Dholakia 2024-03-20 22:43:42 -07:00
parent f24d3ffdb6
commit d91f9a9f50
11 changed files with 271 additions and 24 deletions
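
For context on the mechanism named above: `async_moderation_hook` is the proxy hook that runs a check concurrently with the main request instead of gating it. A minimal sketch of that hook pattern, assuming hypothetical check logic (the guard-model call and the verdict strings below are illustrative, not litellm's implementation):

```python
import asyncio


class PromptInjectionCheckSketch:
    """Illustrative moderation-style hook; the proxy schedules this
    concurrently with the real LLM call rather than awaiting it first."""

    async def async_moderation_hook(self, data: dict, call_type: str) -> None:
        # Hypothetical: send the latest user message to a guard model.
        user_prompt = data["messages"][-1]["content"]
        verdict = await self._classify(user_prompt)
        if verdict == "UNSAFE":
            # Raising here is what rejects the in-flight request.
            raise ValueError("Potential prompt injection detected")

    async def _classify(self, prompt: str) -> str:
        await asyncio.sleep(0)  # stand-in for the guard-model API call
        return "SAFE"
```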

litellm/proxy/proxy_server.py

@@ -107,6 +107,9 @@ from litellm.caching import DualCache
 from litellm.proxy.health_check import perform_health_check
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
+from litellm.proxy.hooks.prompt_injection_detection import (
+    _OPTIONAL_PromptInjectionDetection,
+)
 
 try:
     from litellm._version import version
@@ -284,6 +287,7 @@ proxy_batch_write_at = 60 # in seconds
 litellm_master_key_hash = None
 disable_spend_logs = False
 jwt_handler = JWTHandler()
+prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
 ### INITIALIZE GLOBAL LOGGING OBJECT ###
 proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
@@ -1657,7 +1661,7 @@ class ProxyConfig:
         """
         Load config values into proxy global state
         """
-        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs
+        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj
 
         # Load existing config
         config = await self.get_config(config_file_path=config_file_path)
@@ -1822,8 +1826,21 @@ class ProxyConfig:
                         _OPTIONAL_PromptInjectionDetection,
                     )
 
+                    prompt_injection_params = None
+                    if "prompt_injection_params" in litellm_settings:
+                        prompt_injection_params_in_config = (
+                            litellm_settings["prompt_injection_params"]
+                        )
+                        prompt_injection_params = (
+                            LiteLLMPromptInjectionParams(
+                                **prompt_injection_params_in_config
+                            )
+                        )
+
                     prompt_injection_detection_obj = (
-                        _OPTIONAL_PromptInjectionDetection()
+                        _OPTIONAL_PromptInjectionDetection(
+                            prompt_injection_params=prompt_injection_params,
+                        )
                     )
                     imported_list.append(prompt_injection_detection_obj)
                 elif (
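
The hunk above unpacks the `prompt_injection_params` block from `litellm_settings` into a `LiteLLMPromptInjectionParams` object. That model's definition is not visible in this diff; a rough sketch of the loading pattern, where every field name is an assumption made for illustration:

```python
from typing import Optional

from pydantic import BaseModel


class PromptInjectionParamsSketch(BaseModel):
    """Stand-in for LiteLLMPromptInjectionParams; the field names
    here are assumptions, not the actual model definition."""

    llm_api_check: bool = False           # run the guard-model check?
    llm_api_name: Optional[str] = None    # router alias of the guard model
    llm_api_system_prompt: Optional[str] = None
    llm_api_fail_call_string: Optional[str] = None  # guard reply meaning "blocked"


# The config loader above does the equivalent of:
litellm_settings = {
    "prompt_injection_params": {"llm_api_check": True, "llm_api_name": "guard-model"}
}
params = PromptInjectionParamsSketch(**litellm_settings["prompt_injection_params"])
```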
@@ -2592,6 +2609,8 @@ async def startup_event():
             _run_background_health_check()
         )  # start the background health check coroutine.
+    if prompt_injection_detection_obj is not None:
+        prompt_injection_detection_obj.update_environment(router=llm_router)
 
     verbose_proxy_logger.debug(f"prisma client - {prisma_client}")
     if prisma_client is not None:
         await prisma_client.connect()
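
The detection object is constructed while the config is parsed, before the router exists, so the router is injected later at startup. A sketch of that two-phase wiring (the method body is assumed; only the `update_environment(router=llm_router)` call site appears in the diff):

```python
class DetectionObjectSketch:
    def __init__(self) -> None:
        # The router is not available yet when callbacks are loaded
        # from the config file.
        self.llm_router = None

    def update_environment(self, router=None) -> None:
        # Called once from startup_event, after the router is built,
        # so the hook can route its guard-model completion calls.
        self.llm_router = router
```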
@@ -3011,7 +3030,9 @@ async def chat_completion(
         )
 
         tasks = []
-        tasks.append(proxy_logging_obj.during_call_hook(data=data))
+        tasks.append(
+            proxy_logging_obj.during_call_hook(data=data, call_type="completion")
+        )
 
         start_time = time.time()
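
The `tasks` list is what makes the check non-blocking: the during-call hook and the actual completion are awaited together rather than sequentially. A self-contained sketch of that pattern (all names besides `during_call_hook` are illustrative):

```python
import asyncio


async def during_call_hook(data: dict) -> None:
    await asyncio.sleep(0.1)  # stand-in for the guard-model check
    # Raising here would reject the request.


async def llm_call(data: dict) -> str:
    await asyncio.sleep(0.1)  # stand-in for the real completion call
    return "model response"


async def handle_request(data: dict) -> str:
    # Both coroutines run concurrently, so the injection check adds little
    # latency; a failed check surfaces as an exception from gather.
    _, response = await asyncio.gather(during_call_hook(data), llm_call(data))
    return response


print(asyncio.run(handle_request({"messages": []})))
```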