feat: control Lakera AI per LLM call

This commit is contained in:
Ishaan Jaff 2024-07-03 16:34:23 -07:00
parent 228997b074
commit 1028be6308
3 changed files with 62 additions and 22 deletions

View file

@@ -17,12 +17,9 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.utils import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
StreamingChoices,
)
from litellm.proxy.guardrails.init_guardrails import all_guardrails
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from datetime import datetime
import aiohttp, asyncio
from litellm._logging import verbose_proxy_logger
@@ -43,19 +40,6 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
self.lakera_api_key = os.environ["LAKERA_API_KEY"]
pass
async def should_proceed(self, data: dict) -> bool:
    """
    Decide whether this guardrail applies to the current call.

    When the request metadata carries an explicit ``guardrails`` list,
    run only if this guardrail's name is listed; in every other case
    the guardrail runs by default.
    """
    metadata = data.get("metadata")
    if isinstance(metadata, dict) and "guardrails" in metadata:
        # caller supplied an explicit guardrail selection -> honor it
        return GUARDRAIL_NAME in metadata["guardrails"]
    # no per-request selection -> apply the guardrail
    return True
#### CALL HOOKS - proxy only ####
async def async_moderation_hook( ### 👈 KEY CHANGE ###
@@ -65,7 +49,13 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
call_type: Literal["completion", "embeddings", "image_generation"],
):
if await self.should_proceed(data=data) is False:
if (
await should_proceed_based_on_metadata(
data=data,
guardrail_name=GUARDRAIL_NAME,
)
is False
):
return
if "messages" in data and isinstance(data["messages"], list):

View file

@@ -0,0 +1,46 @@
from litellm._logging import verbose_proxy_logger
from litellm.proxy.guardrails.init_guardrails import guardrail_name_config_map
from litellm.types.guardrails import *
async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> bool:
    """
    Check whether the guardrail whose callback name is ``guardrail_name``
    should run for this request.

    Requests may opt guardrails in/out via ``data["metadata"]["guardrails"]``,
    a mapping of configured guardrail names to booleans, e.g.::

        guardrails: { prompt_injection: true, rail_2: false }

    Args:
        data: the incoming request body; only ``data["metadata"]`` is read.
        guardrail_name: the callback name this hook registered under
            (e.g. the Lakera moderation callback).

    Returns:
        True when no per-request guardrail selection is present (guardrails
        run by default), or when ``guardrail_name`` is among the callbacks
        of a guardrail the request enabled; False otherwise.
    """
    metadata = data.get("metadata")
    if not isinstance(metadata, dict) or "guardrails" not in metadata:
        # no explicit selection in the request -> run by default
        return True

    request_guardrails = metadata["guardrails"]
    verbose_proxy_logger.debug(
        "Guardrails %s passed in request - checking which to apply",
        request_guardrails,
    )

    requested_callback_names = []
    # get guardrail configs from `init_guardrails.py`
    # for all requested guardrails -> get their associated callbacks
    for _guardrail_name, should_run in request_guardrails.items():
        if should_run is False:
            verbose_proxy_logger.debug(
                "Guardrail %s skipped because request set to False",
                _guardrail_name,
            )
            continue

        # lookup the guardrail in guardrail_name_config_map; previously an
        # unknown name raised an opaque KeyError and failed the whole call —
        # log and skip it instead
        guardrail_item = guardrail_name_config_map.get(_guardrail_name)
        if guardrail_item is None:
            verbose_proxy_logger.warning(
                "Guardrail %s not found in config - skipping", _guardrail_name
            )
            continue

        requested_callback_names.extend(guardrail_item.callbacks)

    verbose_proxy_logger.debug(
        "requested_callback_names %s", requested_callback_names
    )
    return guardrail_name in requested_callback_names

View file

@@ -8,6 +8,10 @@ from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.types.guardrails import GuardrailItem
# Module-level registry of every guardrail parsed from the proxy config;
# populated by initialize_guardrails().
all_guardrails: List[GuardrailItem] = []
# Maps a guardrail's config name -> its GuardrailItem, so per-request
# guardrail selections can be resolved to their callbacks.
guardrail_name_config_map: Dict[str, GuardrailItem] = {}
def initialize_guardrails(
guardrails_config: list,
@@ -17,8 +21,7 @@ def initialize_guardrails(
):
try:
verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}")
all_guardrails: List[GuardrailItem] = []
global all_guardrails
for item in guardrails_config:
"""
one item looks like this:
@@ -29,6 +32,7 @@ def initialize_guardrails(
for k, v in item.items():
guardrail_item = GuardrailItem(**v, guardrail_name=k)
all_guardrails.append(guardrail_item)
guardrail_name_config_map[k] = guardrail_item
# set appropriate callbacks if they are default on
default_on_callbacks = set()