from typing import Any, List, Optional, get_args

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.utils import get_instance_fn

blue_color_code = "\033[94m"
reset_color_code = "\033[0m"

# Enterprise hooks that can be enabled by name.
# name -> (module path, class name, prefix of the error raised when the
# deployment is not a premium one).
_ENTERPRISE_CALLBACKS = {
    "llamaguard_moderations": (
        "enterprise.enterprise_hooks.llama_guard",
        "_ENTERPRISE_LlamaGuard",
        "Trying to use Llama Guard",
    ),
    "hide_secrets": (
        "enterprise.enterprise_hooks.secret_detection",
        "_ENTERPRISE_SecretDetection",
        "Trying to use secret hiding",
    ),
    "openai_moderations": (
        "enterprise.enterprise_hooks.openai_moderation",
        "_ENTERPRISE_OpenAI_Moderation",
        "Trying to use OpenAI Moderations Check",
    ),
    "lakera_prompt_injection": (
        "enterprise.enterprise_hooks.lakera_ai",
        "_ENTERPRISE_lakeraAI_Moderation",
        "Trying to use LakeraAI Prompt Injection",
    ),
    "google_text_moderation": (
        "enterprise.enterprise_hooks.google_text_moderation",
        "_ENTERPRISE_GoogleTextModeration",
        "Trying to use Google Text Moderation",
    ),
    "llmguard_moderations": (
        "enterprise.enterprise_hooks.llm_guard",
        "_ENTERPRISE_LLMGuard",
        "Trying to use Llm Guard",
    ),
    "blocked_user_check": (
        "enterprise.enterprise_hooks.blocked_user_list",
        "_ENTERPRISE_BlockedUserList",
        "Trying to use ENTERPRISE BlockedUser",
    ),
    "banned_keywords": (
        "enterprise.enterprise_hooks.banned_keywords",
        "_ENTERPRISE_BannedKeywords",
        "Trying to use ENTERPRISE BannedKeyword",
    ),
}


def _load_enterprise_callback(name: str, premium_user: bool) -> Any:
    """Instantiate the enterprise hook registered under ``name``.

    Raises:
        Exception: if ``premium_user`` is not ``True``.
    """
    import importlib

    module_path, class_name, error_prefix = _ENTERPRISE_CALLBACKS[name]

    # Check the license *before* importing, so a deployment that does not ship
    # the enterprise package gets the licensing error rather than an import
    # error.
    if premium_user is not True:
        raise Exception(error_prefix + CommonProxyErrors.not_premium_user.value)

    hook_cls = getattr(importlib.import_module(module_path), class_name)

    if name == "blocked_user_check":
        # Imported lazily: proxy_server itself imports this module.
        from litellm.proxy.proxy_server import prisma_client

        return hook_cls(prisma_client=prisma_client)
    return hook_cls()


def _build_callback(
    callback: Any,
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
    known_compatible_callbacks: List[Any],
) -> Any:
    """Turn one entry of the ``callbacks`` config list into a callback object."""
    if isinstance(callback, str):
        if callback in known_compatible_callbacks:
            # These names are passed through unchanged.
            return callback

        if callback in _ENTERPRISE_CALLBACKS:
            return _load_enterprise_callback(callback, premium_user)

        if callback == "otel":
            from litellm.integrations.opentelemetry import OpenTelemetry

            return OpenTelemetry()

        if callback == "presidio":
            from litellm.proxy.hooks.presidio_pii_masking import (
                _OPTIONAL_PresidioPIIMasking,
            )

            return _OPTIONAL_PresidioPIIMasking()

        if callback == "detect_prompt_injection":
            from litellm.proxy.hooks.prompt_injection_detection import (
                _OPTIONAL_PromptInjectionDetection,
            )

            prompt_injection_params = None
            if "prompt_injection_params" in litellm_settings:
                prompt_injection_params = LiteLLMPromptInjectionParams(
                    **litellm_settings["prompt_injection_params"]
                )
            return _OPTIONAL_PromptInjectionDetection(
                prompt_injection_params=prompt_injection_params,
            )

        if callback == "batch_redis_requests":
            from litellm.proxy.hooks.batch_redis_get import (
                _PROXY_BatchRedisRequests,
            )

            return _PROXY_BatchRedisRequests()

        if callback == "azure_content_safety":
            from litellm.proxy.hooks.azure_content_safety import (
                _PROXY_AzureContentSafety,
            )

            # Resolve "os.environ/..." references into a *new* dict so the
            # resolved secrets are not written back into the caller's
            # settings. A missing "azure_content_safety_params" key still
            # raises KeyError.
            azure_content_safety_params = {
                k: (
                    litellm.get_secret(v)
                    if isinstance(v, str) and v.startswith("os.environ/")
                    else v
                )
                for k, v in litellm_settings["azure_content_safety_params"].items()
            }
            return _PROXY_AzureContentSafety(**azure_content_safety_params)

    # Anything else is handed to get_instance_fn together with the config
    # file path.
    verbose_proxy_logger.debug(
        f"{blue_color_code} attempting to import custom callback={callback} {reset_color_code}"
    )
    return get_instance_fn(value=callback, config_file_path=config_file_path)


def initialize_callbacks_on_proxy(
    value: Any,
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
) -> None:
    """Build the callbacks named in the proxy config and add them to ``litellm.callbacks``.

    Args:
        value: The ``callbacks`` setting. A list is resolved entry by entry
            (known names, enterprise hooks, built-in hooks, or a custom
            import); any other value is passed to ``get_instance_fn`` as-is.
        premium_user: Must be ``True`` to enable any enterprise hook.
        config_file_path: Forwarded to ``get_instance_fn`` for custom callbacks.
        litellm_settings: The ``litellm_settings`` mapping; read for
            ``prompt_injection_params`` and ``azure_content_safety_params``.

    Raises:
        Exception: if an enterprise hook is requested and ``premium_user`` is
            not ``True``.
        KeyError: if ``azure_content_safety`` is requested without
            ``azure_content_safety_params`` in ``litellm_settings``.

    Both input shapes extend an existing ``litellm.callbacks`` list, so
    callbacks registered earlier are kept. Nothing is registered if building
    any entry fails.
    """
    verbose_proxy_logger.debug(
        f"{blue_color_code}initializing callbacks={value} on proxy{reset_color_code}"
    )

    if isinstance(value, list):
        known_compatible_callbacks = list(
            get_args(litellm._custom_logger_compatible_callbacks_literal)
        )
        imported_list: List[Any] = [
            _build_callback(
                callback=callback,
                premium_user=premium_user,
                config_file_path=config_file_path,
                litellm_settings=litellm_settings,
                known_compatible_callbacks=known_compatible_callbacks,
            )
            for callback in value  # e.g. ["presidio", ]
        ]
    else:
        imported_list = [
            get_instance_fn(
                value=value,
                config_file_path=config_file_path,
            )
        ]

    if isinstance(litellm.callbacks, list):
        litellm.callbacks.extend(imported_list)
    else:
        litellm.callbacks = imported_list  # type: ignore

    verbose_proxy_logger.debug(
        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
    )
import traceback
from typing import Dict, List

from pydantic import BaseModel, RootModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.types.guardrails import GuardrailItem


def initialize_guardrails(
    guardrails_config: list,
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
) -> None:
    """Validate the ``guardrails`` config and enable the default-on callbacks.

    Args:
        guardrails_config: List of mappings, each of the form
            ``{guardrail_name: {"callbacks": [...], "default_on": bool}}``.
        premium_user: Forwarded to ``initialize_callbacks_on_proxy``.
        config_file_path: Forwarded to ``initialize_callbacks_on_proxy``.
        litellm_settings: Forwarded to ``initialize_callbacks_on_proxy``.

    Raises:
        Exception: any error raised while parsing the config or initializing
            callbacks is logged and then re-raised unchanged.
    """
    try:
        verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}")

        all_guardrails: List[GuardrailItem] = []
        for item in guardrails_config:
            # One item looks like this:
            # {'prompt_injection': {'callbacks': ['lakera_prompt_injection',
            #                                     'prompt_injection_api_2'],
            #                       'default_on': True}}
            for k, v in item.items():
                all_guardrails.append(GuardrailItem(**v, guardrail_name=k))

        # Collect the callbacks of every default-on guardrail.
        default_on_callbacks: list = []
        for guardrail in all_guardrails:
            verbose_proxy_logger.debug(guardrail.guardrail_name)
            verbose_proxy_logger.debug(guardrail.default_on)

            if guardrail.default_on is not True:
                continue

            for callback in guardrail.callbacks:
                # NOTE(review): this compares config names against
                # litellm.callbacks; it only filters entries stored there
                # under the same name - confirm this is the intended check.
                if callback in litellm.callbacks:
                    continue
                # The same callback may be listed under several guardrails;
                # queue it once so it is not initialized more than once.
                if callback in default_on_callbacks:
                    continue
                default_on_callbacks.append(callback)

        if default_on_callbacks:
            initialize_callbacks_on_proxy(
                value=default_on_callbacks,
                premium_user=premium_user,
                config_file_path=config_file_path,
                litellm_settings=litellm_settings,
            )

    except Exception as e:
        verbose_proxy_logger.error(f"error initializing guardrails {str(e)}")
        traceback.print_exc()
        raise
a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -142,6 +142,8 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.caching_routes import router as caching_router from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy +from litellm.proxy.guardrails.init_guardrails import initialize_guardrails from litellm.proxy.health_check import perform_health_check from litellm.proxy.health_endpoints._health_endpoints import router as health_router from litellm.proxy.hooks.prompt_injection_detection import ( @@ -1443,248 +1445,28 @@ class ProxyConfig: ) elif key == "cache" and value == False: pass - elif key == "callbacks": - if isinstance(value, list): - imported_list: List[Any] = [] - known_compatible_callbacks = list( - get_args( - litellm._custom_logger_compatible_callbacks_literal - ) + elif key == "guardrails": + if premium_user is not True: + raise ValueError( + "Trying to use `guardrails` on config.yaml " + + CommonProxyErrors.not_premium_user.value ) - for callback in value: # ["presidio", ] - if ( - isinstance(callback, str) - and callback in known_compatible_callbacks - ): - imported_list.append(callback) - elif isinstance(callback, str) and callback == "otel": - from litellm.integrations.opentelemetry import ( - OpenTelemetry, - ) - open_telemetry_logger = OpenTelemetry() - - imported_list.append(open_telemetry_logger) - elif isinstance(callback, str) and callback == "presidio": - from litellm.proxy.hooks.presidio_pii_masking import ( - _OPTIONAL_PresidioPIIMasking, - ) - - pii_masking_object = _OPTIONAL_PresidioPIIMasking() - imported_list.append(pii_masking_object) - elif ( - isinstance(callback, str) - and callback == "llamaguard_moderations" - ): - from enterprise.enterprise_hooks.llama_guard import ( - 
_ENTERPRISE_LlamaGuard, - ) - - if premium_user != True: - raise Exception( - "Trying to use Llama Guard" - + CommonProxyErrors.not_premium_user.value - ) - - llama_guard_object = _ENTERPRISE_LlamaGuard() - imported_list.append(llama_guard_object) - elif ( - isinstance(callback, str) and callback == "hide_secrets" - ): - from enterprise.enterprise_hooks.secret_detection import ( - _ENTERPRISE_SecretDetection, - ) - - if premium_user != True: - raise Exception( - "Trying to use secret hiding" - + CommonProxyErrors.not_premium_user.value - ) - - _secret_detection_object = _ENTERPRISE_SecretDetection() - imported_list.append(_secret_detection_object) - elif ( - isinstance(callback, str) - and callback == "openai_moderations" - ): - from enterprise.enterprise_hooks.openai_moderation import ( - _ENTERPRISE_OpenAI_Moderation, - ) - - if premium_user != True: - raise Exception( - "Trying to use OpenAI Moderations Check" - + CommonProxyErrors.not_premium_user.value - ) - - openai_moderations_object = ( - _ENTERPRISE_OpenAI_Moderation() - ) - imported_list.append(openai_moderations_object) - elif ( - isinstance(callback, str) - and callback == "lakera_prompt_injection" - ): - from enterprise.enterprise_hooks.lakera_ai import ( - _ENTERPRISE_lakeraAI_Moderation, - ) - - if premium_user != True: - raise Exception( - "Trying to use LakeraAI Prompt Injection" - + CommonProxyErrors.not_premium_user.value - ) - - lakera_moderations_object = ( - _ENTERPRISE_lakeraAI_Moderation() - ) - imported_list.append(lakera_moderations_object) - elif ( - isinstance(callback, str) - and callback == "google_text_moderation" - ): - from enterprise.enterprise_hooks.google_text_moderation import ( - _ENTERPRISE_GoogleTextModeration, - ) - - if premium_user != True: - raise Exception( - "Trying to use Google Text Moderation" - + CommonProxyErrors.not_premium_user.value - ) - - google_text_moderation_obj = ( - _ENTERPRISE_GoogleTextModeration() - ) - imported_list.append(google_text_moderation_obj) 
- elif ( - isinstance(callback, str) - and callback == "llmguard_moderations" - ): - from enterprise.enterprise_hooks.llm_guard import ( - _ENTERPRISE_LLMGuard, - ) - - if premium_user != True: - raise Exception( - "Trying to use Llm Guard" - + CommonProxyErrors.not_premium_user.value - ) - - llm_guard_moderation_obj = _ENTERPRISE_LLMGuard() - imported_list.append(llm_guard_moderation_obj) - elif ( - isinstance(callback, str) - and callback == "blocked_user_check" - ): - from enterprise.enterprise_hooks.blocked_user_list import ( - _ENTERPRISE_BlockedUserList, - ) - - if premium_user != True: - raise Exception( - "Trying to use ENTERPRISE BlockedUser" - + CommonProxyErrors.not_premium_user.value - ) - - blocked_user_list = _ENTERPRISE_BlockedUserList( - prisma_client=prisma_client - ) - imported_list.append(blocked_user_list) - elif ( - isinstance(callback, str) - and callback == "banned_keywords" - ): - from enterprise.enterprise_hooks.banned_keywords import ( - _ENTERPRISE_BannedKeywords, - ) - - if premium_user != True: - raise Exception( - "Trying to use ENTERPRISE BannedKeyword" - + CommonProxyErrors.not_premium_user.value - ) - - banned_keywords_obj = _ENTERPRISE_BannedKeywords() - imported_list.append(banned_keywords_obj) - elif ( - isinstance(callback, str) - and callback == "detect_prompt_injection" - ): - from litellm.proxy.hooks.prompt_injection_detection import ( - _OPTIONAL_PromptInjectionDetection, - ) - - prompt_injection_params = None - if "prompt_injection_params" in litellm_settings: - prompt_injection_params_in_config = ( - litellm_settings["prompt_injection_params"] - ) - prompt_injection_params = ( - LiteLLMPromptInjectionParams( - **prompt_injection_params_in_config - ) - ) - - prompt_injection_detection_obj = ( - _OPTIONAL_PromptInjectionDetection( - prompt_injection_params=prompt_injection_params, - ) - ) - imported_list.append(prompt_injection_detection_obj) - elif ( - isinstance(callback, str) - and callback == "batch_redis_requests" - ): 
- from litellm.proxy.hooks.batch_redis_get import ( - _PROXY_BatchRedisRequests, - ) - - batch_redis_obj = _PROXY_BatchRedisRequests() - imported_list.append(batch_redis_obj) - elif ( - isinstance(callback, str) - and callback == "azure_content_safety" - ): - from litellm.proxy.hooks.azure_content_safety import ( - _PROXY_AzureContentSafety, - ) - - azure_content_safety_params = litellm_settings[ - "azure_content_safety_params" - ] - for k, v in azure_content_safety_params.items(): - if ( - v is not None - and isinstance(v, str) - and v.startswith("os.environ/") - ): - azure_content_safety_params[k] = ( - litellm.get_secret(v) - ) - - azure_content_safety_obj = _PROXY_AzureContentSafety( - **azure_content_safety_params, - ) - imported_list.append(azure_content_safety_obj) - else: - imported_list.append( - get_instance_fn( - value=callback, - config_file_path=config_file_path, - ) - ) - litellm.callbacks = imported_list # type: ignore - else: - litellm.callbacks = [ - get_instance_fn( - value=value, - config_file_path=config_file_path, - ) - ] - verbose_proxy_logger.debug( - f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}" + initialize_guardrails( + guardrails_config=value, + premium_user=premium_user, + config_file_path=config_file_path, + litellm_settings=litellm_settings, ) + elif key == "callbacks": + + initialize_callbacks_on_proxy( + value=value, + premium_user=premium_user, + config_file_path=config_file_path, + litellm_settings=litellm_settings, + ) + elif key == "post_call_rules": litellm.post_call_rules = [ get_instance_fn(value=value, config_file_path=config_file_path)