Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
feat(proxy_server.py): enable llm api based prompt injection checks

Run user calls through an LLM API to check for prompt injection attacks. This happens in parallel to the actual LLM call, using `async_moderation_hook`.
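A minimal sketch of the pattern the commit message describes: start the moderation check concurrently with the main completion call and surface a rejection before the caller sees the response. The helper names (`check_prompt_injection`, `call_main_llm`) are illustrative, not litellm's actual API.

```python
# Illustrative sketch only -- helper names are hypothetical, not litellm's API.
import asyncio


async def check_prompt_injection(messages: list) -> None:
    """Stand-in for an LLM-based classifier; raises if the prompt looks malicious."""
    await asyncio.sleep(0.1)  # simulates the moderation model call
    text = " ".join(str(m.get("content", "")) for m in messages)
    if "ignore previous instructions" in text.lower():
        raise ValueError("Rejected: possible prompt injection")


async def call_main_llm(messages: list) -> str:
    """Stand-in for the actual completion call."""
    await asyncio.sleep(0.2)
    return "model response"


async def chat_completion(messages: list) -> str:
    # Start the moderation check first so it runs concurrently with the main call.
    moderation_task = asyncio.create_task(check_prompt_injection(messages))
    response = await call_main_llm(messages)
    await moderation_task  # re-raises here if the check rejected the prompt
    return response


if __name__ == "__main__":
    print(asyncio.run(chat_completion([{"role": "user", "content": "hello"}])))
```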
Parent: f24d3ffdb6
Commit: d91f9a9f50
11 changed files with 271 additions and 24 deletions
litellm/proxy/proxy_server.py

@@ -107,6 +107,9 @@ from litellm.caching import DualCache
 from litellm.proxy.health_check import perform_health_check
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
+from litellm.proxy.hooks.prompt_injection_detection import (
+    _OPTIONAL_PromptInjectionDetection,
+)
 
 try:
     from litellm._version import version
@@ -284,6 +287,7 @@ proxy_batch_write_at = 60  # in seconds
 litellm_master_key_hash = None
 disable_spend_logs = False
 jwt_handler = JWTHandler()
+prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
 ### INITIALIZE GLOBAL LOGGING OBJECT ###
 proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
@@ -1657,7 +1661,7 @@ class ProxyConfig:
         """
         Load config values into proxy global state
         """
-        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs
+        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj
 
         # Load existing config
         config = await self.get_config(config_file_path=config_file_path)
@@ -1822,8 +1826,21 @@ class ProxyConfig:
                         _OPTIONAL_PromptInjectionDetection,
                     )
 
+                    prompt_injection_params = None
+                    if "prompt_injection_params" in litellm_settings:
+                        prompt_injection_params_in_config = (
+                            litellm_settings["prompt_injection_params"]
+                        )
+                        prompt_injection_params = (
+                            LiteLLMPromptInjectionParams(
+                                **prompt_injection_params_in_config
+                            )
+                        )
+
                     prompt_injection_detection_obj = (
-                        _OPTIONAL_PromptInjectionDetection()
+                        _OPTIONAL_PromptInjectionDetection(
+                            prompt_injection_params=prompt_injection_params,
+                        )
                     )
                     imported_list.append(prompt_injection_detection_obj)
                 elif (
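The config-loading change above follows a common pattern: read an optional dict out of `litellm_settings` and unpack it into a typed params object. A hedged sketch of that pattern is below; the field names are placeholders, since the real ones are whatever `LiteLLMPromptInjectionParams` defines in litellm.

```python
# Sketch of unpacking an optional settings dict into a typed params object.
# Field names are placeholders, not the real LiteLLMPromptInjectionParams fields.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PromptInjectionParams:
    llm_api_check: bool = False          # hypothetical: run an LLM-based check
    llm_api_name: Optional[str] = None   # hypothetical: model alias used for the check


litellm_settings = {
    "prompt_injection_params": {"llm_api_check": True, "llm_api_name": "gpt-3.5-turbo"}
}

prompt_injection_params = None
if "prompt_injection_params" in litellm_settings:
    prompt_injection_params = PromptInjectionParams(
        **litellm_settings["prompt_injection_params"]
    )

print(prompt_injection_params)
```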
@@ -2592,6 +2609,8 @@ async def startup_event():
             _run_background_health_check()
         )  # start the background health check coroutine.
 
+    if prompt_injection_detection_obj is not None:
+        prompt_injection_detection_obj.update_environment(router=llm_router)
     verbose_proxy_logger.debug(f"prisma client - {prisma_client}")
     if prisma_client is not None:
         await prisma_client.connect()
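The startup hook above is a case of deferred dependency injection: the detection object is built while the config is parsed, before the proxy's router exists, so the router is attached later in `startup_event` via `update_environment`. A minimal sketch of that shape (the method name mirrors the diff; the class bodies are illustrative):

```python
# Sketch of late-binding a router into a hook object created during config parsing.
from typing import Optional


class Router:
    """Stand-in for litellm's Router."""
    async def acompletion(self, **kwargs):
        return {"choices": []}


class PromptInjectionDetection:
    def __init__(self, prompt_injection_params=None):
        self.prompt_injection_params = prompt_injection_params
        self.llm_router: Optional[Router] = None  # not available at construction time

    def update_environment(self, router: Optional[Router] = None) -> None:
        # Called once the proxy's router has been built at startup.
        self.llm_router = router


# Config-parsing time: no router yet.
detector = PromptInjectionDetection()

# Startup time: the router now exists, so hand it to the hook.
detector.update_environment(router=Router())
assert detector.llm_router is not None
```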
@@ -3011,7 +3030,9 @@ async def chat_completion(
         )
 
         tasks = []
-        tasks.append(proxy_logging_obj.during_call_hook(data=data))
+        tasks.append(
+            proxy_logging_obj.during_call_hook(data=data, call_type="completion")
+        )
 
         start_time = time.time()
 
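Passing `call_type` into `during_call_hook` lets each hook decide whether it applies to the current endpoint. A rough sketch of a hook and caller following this shape (`async_moderation_hook`, `during_call_hook`, and `call_type="completion"` come from the commit; everything else is illustrative):

```python
# Sketch: a during-call hook that only moderates completion-style calls,
# gathered alongside other per-request tasks. Bodies are illustrative.
import asyncio
from typing import Literal


class ModerationHook:
    async def async_moderation_hook(
        self, data: dict, call_type: Literal["completion", "embeddings"]
    ) -> None:
        if call_type != "completion":
            return  # nothing to check for non-text calls in this sketch
        text = str(data.get("messages", ""))
        if "ignore previous instructions" in text.lower():
            raise ValueError("Rejected: possible prompt injection")


async def during_call_hook(data: dict, call_type: str) -> None:
    # The proxy would fan this out to every registered hook; one is enough here.
    await ModerationHook().async_moderation_hook(data=data, call_type=call_type)


async def handle_request(data: dict) -> None:
    tasks = []
    tasks.append(during_call_hook(data=data, call_type="completion"))
    await asyncio.gather(*tasks)


asyncio.run(handle_request({"messages": [{"role": "user", "content": "hi"}]}))
```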