From 1028be63080ff77a6dce2a7b472ce1fa7dd4a6c8 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Wed, 3 Jul 2024 16:34:23 -0700
Subject: [PATCH] feat: control Lakera AI per LLM call

---
 enterprise/enterprise_hooks/lakera_ai.py      | 30 ++++--------
 litellm/proxy/guardrails/guardrail_helpers.py | 46 +++++++++++++++++++
 litellm/proxy/guardrails/init_guardrails.py   |  8 +++-
 3 files changed, 62 insertions(+), 22 deletions(-)
 create mode 100644 litellm/proxy/guardrails/guardrail_helpers.py

diff --git a/enterprise/enterprise_hooks/lakera_ai.py b/enterprise/enterprise_hooks/lakera_ai.py
index 3d874da8d..642589a25 100644
--- a/enterprise/enterprise_hooks/lakera_ai.py
+++ b/enterprise/enterprise_hooks/lakera_ai.py
@@ -17,12 +17,9 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
 from litellm._logging import verbose_proxy_logger
-from litellm.utils import (
-    ModelResponse,
-    EmbeddingResponse,
-    ImageResponse,
-    StreamingChoices,
-)
+from litellm.proxy.guardrails.init_guardrails import all_guardrails
+from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
+
 from datetime import datetime
 import aiohttp, asyncio
 from litellm._logging import verbose_proxy_logger
@@ -43,19 +40,6 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
         self.lakera_api_key = os.environ["LAKERA_API_KEY"]
         pass
 
-    async def should_proceed(self, data: dict) -> bool:
-        """
-        checks if this guardrail should be applied to this call
-        """
-        if "metadata" in data and isinstance(data["metadata"], dict):
-            if "guardrails" in data["metadata"]:
-                # if guardrails passed in metadata -> this is a list of guardrails the user wants to run on the call
-                if GUARDRAIL_NAME not in data["metadata"]["guardrails"]:
-                    return False
-
-        # in all other cases it should proceed
-        return True
-
     #### CALL HOOKS - proxy only ####
 
     async def async_moderation_hook(  ### 👈 KEY CHANGE ###
@@ -65,7 +49,13 @@
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
 
-        if await self.should_proceed(data=data) is False:
+        if (
+            await should_proceed_based_on_metadata(
+                data=data,
+                guardrail_name=GUARDRAIL_NAME,
+            )
+            is False
+        ):
             return
 
         if "messages" in data and isinstance(data["messages"], list):
diff --git a/litellm/proxy/guardrails/guardrail_helpers.py b/litellm/proxy/guardrails/guardrail_helpers.py
new file mode 100644
index 000000000..39c9a9831
--- /dev/null
+++ b/litellm/proxy/guardrails/guardrail_helpers.py
@@ -0,0 +1,46 @@
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy.guardrails.init_guardrails import guardrail_name_config_map
+from litellm.types.guardrails import *
+
+
+async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> bool:
+    """
+    Checks whether this guardrail should be applied to this call.
+    """
+    if "metadata" in data and isinstance(data["metadata"], dict):
+        if "guardrails" in data["metadata"]:
+            # expect users to pass
+            # guardrails: { prompt_injection: true, rail_2: false }
+            request_guardrails = data["metadata"]["guardrails"]
+            verbose_proxy_logger.debug(
+                "Guardrails %s passed in request - checking which to apply",
+                request_guardrails,
+            )
+
+            requested_callback_names = []
+
+            # get guardrail configs from `init_guardrails.py`
+            # for all requested guardrails -> get their associated callbacks
+            for _guardrail_name, should_run in request_guardrails.items():
+                if should_run is False:
+                    verbose_proxy_logger.debug(
"Guardrail %s skipped because request set to False", + _guardrail_name, + ) + continue + + # lookup the guardrail in guardrail_name_config_map + guardrail_item: GuardrailItem = guardrail_name_config_map[ + _guardrail_name + ] + + guardrail_callbacks = guardrail_item.callbacks + requested_callback_names.extend(guardrail_callbacks) + + verbose_proxy_logger.debug( + "requested_callback_names %s", requested_callback_names + ) + if guardrail_name in requested_callback_names: + return True + + return False diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index 4cf451019..9c9fde533 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -8,6 +8,10 @@ from litellm._logging import verbose_proxy_logger from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy from litellm.types.guardrails import GuardrailItem +all_guardrails: List[GuardrailItem] = [] + +guardrail_name_config_map: Dict[str, GuardrailItem] = {} + def initialize_guardrails( guardrails_config: list, @@ -17,8 +21,7 @@ def initialize_guardrails( ): try: verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}") - - all_guardrails: List[GuardrailItem] = [] + global all_guardrails for item in guardrails_config: """ one item looks like this: @@ -29,6 +32,7 @@ def initialize_guardrails( for k, v in item.items(): guardrail_item = GuardrailItem(**v, guardrail_name=k) all_guardrails.append(guardrail_item) + guardrail_name_config_map[k] = guardrail_item # set appropriate callbacks if they are default on default_on_callbacks = set()