forked from phoenix/litellm-mirror
feat(proxy_server.py): enable admin to set banned keywords on proxy
This commit is contained in:
parent
6ad450396b
commit
acae98fd50
4 changed files with 177 additions and 0 deletions
103
enterprise/enterprise_hooks/banned_keywords.py
Normal file
103
enterprise/enterprise_hooks/banned_keywords.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
# +------------------------------+
|
||||
#
|
||||
# Banned Keywords
|
||||
#
|
||||
# +------------------------------+
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
## Reject a call / response if it contains certain keywords
|
||||
|
||||
|
||||
from typing import Optional, Literal
|
||||
import litellm
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from fastapi import HTTPException
|
||||
import json, traceback
|
||||
|
||||
|
||||
class _ENTERPRISE_BannedKeywords(CustomLogger):
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
banned_keywords_list = litellm.banned_keywords_list
|
||||
|
||||
if banned_keywords_list is None:
|
||||
raise Exception(
|
||||
"`banned_keywords_list` can either be a list or filepath. None set."
|
||||
)
|
||||
|
||||
if isinstance(banned_keywords_list, list):
|
||||
self.banned_keywords_list = banned_keywords_list
|
||||
|
||||
if isinstance(banned_keywords_list, str): # assume it's a filepath
|
||||
try:
|
||||
with open(banned_keywords_list, "r") as file:
|
||||
data = file.read()
|
||||
self.banned_keywords_list = data.split("\n")
|
||||
except FileNotFoundError:
|
||||
raise Exception(
|
||||
f"File not found. banned_keywords_list={banned_keywords_list}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred: {str(e)}, banned_keywords_list={banned_keywords_list}"
|
||||
)
|
||||
|
||||
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
|
||||
if level == "INFO":
|
||||
verbose_proxy_logger.info(print_statement)
|
||||
elif level == "DEBUG":
|
||||
verbose_proxy_logger.debug(print_statement)
|
||||
|
||||
if litellm.set_verbose is True:
|
||||
print(print_statement) # noqa
|
||||
|
||||
def test_violation(self, test_str: str):
|
||||
for word in self.banned_keywords_list:
|
||||
if word in test_str.lower():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": f"Keyword banned. Keyword={word}"},
|
||||
)
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
|
||||
):
|
||||
try:
|
||||
"""
|
||||
- check if user id part of call
|
||||
- check if user id part of blocked list
|
||||
"""
|
||||
self.print_verbose(f"Inside Banned Keyword List Pre-Call Hook")
|
||||
if call_type == "completion" and "messages" in data:
|
||||
for m in data["messages"]:
|
||||
if "content" in m and isinstance(m["content"], str):
|
||||
self.test_violation(test_str=m["content"])
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
):
|
||||
if isinstance(response, litellm.ModelResponse) and isinstance(
|
||||
response.choices[0], litellm.utils.Choices
|
||||
):
|
||||
for word in self.banned_keywords_list:
|
||||
self.test_violation(test_str=response.choices[0].message.content)
|
||||
|
||||
async def async_post_call_streaming_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: str,
|
||||
):
|
||||
self.test_violation(test_str=response)
|
|
@ -61,6 +61,7 @@ presidio_ad_hoc_recognizers: Optional[str] = None
|
|||
google_moderation_confidence_threshold: Optional[float] = None
|
||||
llamaguard_unsafe_content_categories: Optional[str] = None
|
||||
blocked_user_list: Optional[Union[str, List]] = None
|
||||
banned_keywords_list: Optional[Union[str, List]] = None
|
||||
##################
|
||||
logging: bool = True
|
||||
caching: bool = (
|
||||
|
|
|
@ -1489,6 +1489,16 @@ class ProxyConfig:
|
|||
|
||||
blocked_user_list = _ENTERPRISE_BlockedUserList()
|
||||
imported_list.append(blocked_user_list)
|
||||
elif (
|
||||
isinstance(callback, str)
|
||||
and callback == "banned_keywords"
|
||||
):
|
||||
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
|
||||
_ENTERPRISE_BannedKeywords,
|
||||
)
|
||||
|
||||
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
|
||||
imported_list.append(banned_keywords_obj)
|
||||
else:
|
||||
imported_list.append(
|
||||
get_instance_fn(
|
||||
|
|
63
litellm/tests/test_banned_keyword_list.py
Normal file
63
litellm/tests/test_banned_keyword_list.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
# What is this?
|
||||
## This tests the blocked user pre call hook for the proxy server
|
||||
|
||||
|
||||
import sys, os, asyncio, time, random
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
|
||||
_ENTERPRISE_BannedKeywords,
|
||||
)
|
||||
from litellm import Router, mock_completion
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_banned_keywords_check():
|
||||
"""
|
||||
- Set some banned keywords as a litellm module value
|
||||
- Test to see if a call with banned keywords is made, an error is raised
|
||||
- Test to see if a call without banned keywords is made it passes
|
||||
"""
|
||||
litellm.banned_keywords_list = ["hello"]
|
||||
|
||||
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
|
||||
|
||||
_api_key = "sk-12345"
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||
local_cache = DualCache()
|
||||
|
||||
## Case 1: blocked user id passed
|
||||
try:
|
||||
await banned_keywords_obj.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=local_cache,
|
||||
call_type="completion",
|
||||
data={"messages": [{"role": "user", "content": "Hello world"}]},
|
||||
)
|
||||
pytest.fail(f"Expected call to fail")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
## Case 2: normal user id passed
|
||||
try:
|
||||
await banned_keywords_obj.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=local_cache,
|
||||
call_type="completion",
|
||||
data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]},
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"An error occurred - {str(e)}")
|
Loading…
Add table
Add a link
Reference in a new issue