diff --git a/docs/my-website/docs/proxy/guardrails.md b/docs/my-website/docs/proxy/guardrails.md index 698e97f9a..451ca8ab5 100644 --- a/docs/my-website/docs/proxy/guardrails.md +++ b/docs/my-website/docs/proxy/guardrails.md @@ -1,18 +1,10 @@ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; -# 🛡️ Guardrails +# 🛡️ [Beta] Guardrails Setup Prompt Injection Detection, Secret Detection on LiteLLM Proxy -:::info - -✨ Enterprise Only Feature - -Schedule a meeting with us to get an Enterprise License 👉 Talk to founders [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) - -::: - ## Quick Start ### 1. Setup guardrails on litellm proxy config.yaml diff --git a/enterprise/enterprise_hooks/aporio_ai.py b/enterprise/enterprise_hooks/aporio_ai.py index b0b1b50c9..9929760b4 100644 --- a/enterprise/enterprise_hooks/aporio_ai.py +++ b/enterprise/enterprise_hooks/aporio_ai.py @@ -15,7 +15,7 @@ from typing import Optional, Literal, Union, Any import litellm, traceback, sys, uuid from litellm.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth -from litellm.integrations.custom_logger import CustomLogger +from litellm.integrations.custom_guardrail import CustomGuardrail from fastapi import HTTPException from litellm._logging import verbose_proxy_logger from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata @@ -29,19 +29,25 @@ from litellm._logging import verbose_proxy_logger from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler import httpx import json +from litellm.types.guardrails import GuardrailEventHooks litellm.set_verbose = True GUARDRAIL_NAME = "aporio" -class _ENTERPRISE_Aporio(CustomLogger): - def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None): +class _ENTERPRISE_Aporio(CustomGuardrail): + def __init__( + self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs + ): self.async_handler = AsyncHTTPHandler( 
timeout=httpx.Timeout(timeout=600.0, connect=5.0) ) self.aporio_api_key = api_key or os.environ["APORIO_API_KEY"] self.aporio_api_base = api_base or os.environ["APORIO_API_BASE"] + self.event_hook: GuardrailEventHooks + + super().__init__(**kwargs) #### CALL HOOKS - proxy only #### def transform_messages(self, messages: List[dict]) -> List[dict]: @@ -140,10 +146,15 @@ class _ENTERPRISE_Aporio(CustomLogger): from litellm.proxy.common_utils.callback_utils import ( add_guardrail_to_applied_guardrails_header, ) + from litellm.types.guardrails import GuardrailEventHooks """ Use this for the post call moderation with Guardrails """ + event_type: GuardrailEventHooks = GuardrailEventHooks.post_call + if self.should_run_guardrail(data=data, event_type=event_type) is not True: + return + response_str: Optional[str] = convert_litellm_response_object_to_str(response) if response_str is not None: await self.make_aporia_api_request( @@ -151,7 +162,7 @@ class _ENTERPRISE_Aporio(CustomLogger): ) add_guardrail_to_applied_guardrails_header( - request_data=data, guardrail_name=f"post_call_{GUARDRAIL_NAME}" + request_data=data, guardrail_name=self.guardrail_name ) pass @@ -165,7 +176,13 @@ class _ENTERPRISE_Aporio(CustomLogger): from litellm.proxy.common_utils.callback_utils import ( add_guardrail_to_applied_guardrails_header, ) + from litellm.types.guardrails import GuardrailEventHooks + event_type: GuardrailEventHooks = GuardrailEventHooks.during_call + if self.should_run_guardrail(data=data, event_type=event_type) is not True: + return + + # old implementation - backwards compatibility if ( await should_proceed_based_on_metadata( data=data, @@ -182,7 +199,7 @@ class _ENTERPRISE_Aporio(CustomLogger): if new_messages is not None: await self.make_aporia_api_request(new_messages=new_messages) add_guardrail_to_applied_guardrails_header( - request_data=data, guardrail_name=f"during_call_{GUARDRAIL_NAME}" + request_data=data, guardrail_name=self.guardrail_name ) else: 
verbose_proxy_logger.warning( diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py new file mode 100644 index 000000000..a3ac2ea86 --- /dev/null +++ b/litellm/integrations/custom_guardrail.py @@ -0,0 +1,32 @@ +from typing import Literal + +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.guardrails import GuardrailEventHooks + + +class CustomGuardrail(CustomLogger): + + def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs): + self.guardrail_name = guardrail_name + self.event_hook: GuardrailEventHooks = event_hook + super().__init__(**kwargs) + + def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool: + verbose_logger.debug( + "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s", + self.guardrail_name, + event_type, + self.event_hook, + ) + + metadata = data.get("metadata") or {} + requested_guardrails = metadata.get("guardrails") or [] + + if self.guardrail_name not in requested_guardrails: + return False + + if self.event_hook != event_type: + return False + + return True diff --git a/litellm/proxy/guardrails/guardrail_helpers.py b/litellm/proxy/guardrails/guardrail_helpers.py index e0a5f1eb3..a57b965c8 100644 --- a/litellm/proxy/guardrails/guardrail_helpers.py +++ b/litellm/proxy/guardrails/guardrail_helpers.py @@ -37,32 +37,35 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b requested_callback_names = [] - # get guardrail configs from `init_guardrails.py` - # for all requested guardrails -> get their associated callbacks - for _guardrail_name, should_run in request_guardrails.items(): - if should_run is False: - verbose_proxy_logger.debug( - "Guardrail %s skipped because request set to False", - _guardrail_name, - ) - continue + # v1 implementation of this + if isinstance(request_guardrails, dict): - # lookup the 
guardrail in guardrail_name_config_map - guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[ - _guardrail_name - ] + # get guardrail configs from `init_guardrails.py` + # for all requested guardrails -> get their associated callbacks + for _guardrail_name, should_run in request_guardrails.items(): + if should_run is False: + verbose_proxy_logger.debug( + "Guardrail %s skipped because request set to False", + _guardrail_name, + ) + continue - guardrail_callbacks = guardrail_item.callbacks - requested_callback_names.extend(guardrail_callbacks) + # lookup the guardrail in guardrail_name_config_map + guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[ + _guardrail_name + ] - verbose_proxy_logger.debug( - "requested_callback_names %s", requested_callback_names - ) - if guardrail_name in requested_callback_names: - return True + guardrail_callbacks = guardrail_item.callbacks + requested_callback_names.extend(guardrail_callbacks) - # Do no proceeed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } } - return False + verbose_proxy_logger.debug( + "requested_callback_names %s", requested_callback_names + ) + if guardrail_name in requested_callback_names: + return True + + # Do not proceed if - "metadata": { "guardrails": { "lakera_prompt_injection": false } } + return False return True diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index 8bf476311..b1855033d 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -1,12 +1,20 @@ import traceback -from typing import Dict, List +from typing import Dict, List, Literal from pydantic import BaseModel, RootModel import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy -from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec + +# v2 implementation +from 
litellm.types.guardrails import ( + Guardrail, + GuardrailItem, + GuardrailItemSpec, + LitellmParams, + guardrailConfig, +) all_guardrails: List[GuardrailItem] = [] @@ -66,3 +74,70 @@ def initialize_guardrails( "error initializing guardrails {}".format(str(e)) ) raise e + + +""" +Map guardrail_name: , , during_call + +""" + + +def init_guardrails_v2(all_guardrails: dict): + # Convert the loaded data to the TypedDict structure + guardrail_list = [] + + # Parse each guardrail and replace environment variables + for guardrail in all_guardrails: + + # Init litellm params for guardrail + litellm_params_data = guardrail["litellm_params"] + verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data) + litellm_params = LitellmParams( + guardrail=litellm_params_data["guardrail"], + mode=litellm_params_data["mode"], + api_key=litellm_params_data["api_key"], + api_base=litellm_params_data["api_base"], + ) + + if litellm_params["api_key"]: + if litellm_params["api_key"].startswith("os.environ/"): + litellm_params["api_key"] = litellm.get_secret( + litellm_params["api_key"] + ) + + if litellm_params["api_base"]: + if litellm_params["api_base"].startswith("os.environ/"): + litellm_params["api_base"] = litellm.get_secret( + litellm_params["api_base"] + ) + + # Init guardrail CustomLoggerClass + if litellm_params["guardrail"] == "aporia": + from litellm.proxy.enterprise.enterprise_hooks.aporio_ai import ( + _ENTERPRISE_Aporio, + ) + + _aporia_callback = _ENTERPRISE_Aporio( + api_base=litellm_params["api_base"], + api_key=litellm_params["api_key"], + guardrail_name=guardrail["guardrail_name"], + event_hook=litellm_params["mode"], + ) + litellm.callbacks.append(_aporia_callback) # type: ignore + elif litellm_params["guardrail"] == "lakera": + from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import ( + _ENTERPRISE_lakeraAI_Moderation, + ) + + _lakera_callback = _ENTERPRISE_lakeraAI_Moderation() + litellm.callbacks.append(_lakera_callback) # type: ignore + + 
parsed_guardrail = Guardrail( + guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params + ) + + guardrail_list.append(parsed_guardrail) + guardrail_name = guardrail["guardrail_name"] + + # pretty print guardrail_list in green + print(f"\nGuardrail List:{guardrail_list}\n") # noqa diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index dd39efd6b..b38a326f3 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -308,9 +308,20 @@ async def add_litellm_data_to_request( for k, v in callback_settings_obj.callback_vars.items(): data[k] = v + # Guardrails + move_guardrails_to_metadata( + data=data, _metadata_variable_name=_metadata_variable_name + ) + return data +def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str): + if "guardrails" in data: + data[_metadata_variable_name]["guardrails"] = data["guardrails"] + del data["guardrails"] + + def add_provider_specific_headers_to_request( data: dict, headers: dict, diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 6027b8b1c..e36b555a9 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -5,14 +5,15 @@ model_list: api_key: os.environ/OPENAI_API_KEY guardrails: - - guardrail_name: prompt_injection_detection + - guardrail_name: "aporia-pre-guard" litellm_params: - guardrail_name: openai/gpt-3.5-turbo - api_key: os.environ/OPENAI_API_KEY - api_base: os.environ/OPENAI_API_BASE - - guardrail_name: prompt_injection_detection + guardrail: aporia # supported values: "aporia", "bedrock", "lakera" + mode: "during_call" + api_key: os.environ/APORIA_API_KEY_1 + api_base: os.environ/APORIA_API_BASE_1 + - guardrail_name: "aporia-post-guard" litellm_params: - guardrail_name: openai/gpt-3.5-turbo - api_key: os.environ/OPENAI_API_KEY - api_base: os.environ/OPENAI_API_BASE - + guardrail: aporia # supported values: "aporia", "bedrock", "lakera" + mode: 
"post_call" + api_key: os.environ/APORIA_API_KEY_2 + api_base: os.environ/APORIA_API_BASE_2 \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a759dd973..199bf53a7 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -169,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import ( ) from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config -from litellm.proxy.guardrails.init_guardrails import initialize_guardrails +from litellm.proxy.guardrails.init_guardrails import ( + init_guardrails_v2, + initialize_guardrails, +) from litellm.proxy.health_check import perform_health_check from litellm.proxy.health_endpoints._health_endpoints import router as health_router from litellm.proxy.hooks.prompt_injection_detection import ( @@ -1939,6 +1942,11 @@ class ProxyConfig: async_only_mode=True # only init async clients ), ) # type:ignore + + # Guardrail settings + guardrails_v2 = config.get("guardrails", None) + if guardrails_v2: + init_guardrails_v2(all_guardrails=guardrails_v2) return router, router.get_model_list(), general_settings def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo: diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index 0296d8de4..cd9f76f17 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TypedDict from pydantic import BaseModel, ConfigDict from typing_extensions import Required, TypedDict @@ -63,3 +63,26 @@ class GuardrailItem(BaseModel): enabled_roles=enabled_roles, callback_args=callback_args, ) + + +# Define the TypedDicts +class LitellmParams(TypedDict): + guardrail: str + mode: str + api_key: str + api_base: Optional[str] + + +class 
Guardrail(TypedDict): + guardrail_name: str + litellm_params: LitellmParams + + +class guardrailConfig(TypedDict): + guardrails: List[Guardrail] + + +class GuardrailEventHooks(str, Enum): + pre_call = "pre_call" + post_call = "post_call" + during_call = "during_call"