class _ENTERPRISE_BannedKeywords(CustomLogger):
    """Proxy hook that rejects calls / responses containing banned keywords.

    The keyword list is read once, at construction time, from
    ``litellm.banned_keywords_list``. That setting is either a list of
    keywords or a path to a text file holding one keyword per line.
    Matching is a case-insensitive substring check (see ``test_violation``).
    """

    def __init__(self):
        """Load the banned keywords from ``litellm.banned_keywords_list``.

        Raises:
            Exception: if the setting is ``None``, is neither a list nor a
                string, or is a file path that cannot be read.
        """
        banned_keywords_list = litellm.banned_keywords_list

        if banned_keywords_list is None:
            raise Exception(
                "`banned_keywords_list` can either be a list or filepath. None set."
            )

        if isinstance(banned_keywords_list, list):
            self.banned_keywords_list = banned_keywords_list
        elif isinstance(banned_keywords_list, str):  # assume it's a filepath
            try:
                with open(banned_keywords_list, "r") as file:
                    data = file.read()
            except FileNotFoundError as e:
                raise Exception(
                    f"File not found. banned_keywords_list={banned_keywords_list}"
                ) from e
            except Exception as e:
                raise Exception(
                    f"An error occurred: {str(e)}, banned_keywords_list={banned_keywords_list}"
                ) from e
            # Drop blank lines: a trailing newline would otherwise produce an
            # empty keyword, and "" is a substring of every string, so every
            # single request would be rejected.
            self.banned_keywords_list = [
                line for line in data.splitlines() if line.strip()
            ]
        else:
            # Fail loudly here instead of leaving the attribute unset and
            # hitting an AttributeError on the first request.
            raise Exception(
                "`banned_keywords_list` can either be a list or filepath. "
                f"Got type={type(banned_keywords_list).__name__}."
            )

    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
        """Log ``print_statement`` on the proxy logger at the given level.

        The statement is additionally printed to stdout when
        ``litellm.set_verbose`` is True.
        """
        if level == "INFO":
            verbose_proxy_logger.info(print_statement)
        elif level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)

        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    def test_violation(self, test_str: str):
        """Raise if ``test_str`` contains any banned keyword.

        The comparison is a case-insensitive substring match. Blank or
        non-string keywords are ignored, and a missing / non-string
        ``test_str`` is treated as containing nothing.

        Raises:
            HTTPException: status 400, naming the first keyword that matched.
        """
        if not isinstance(test_str, str) or not test_str:
            return

        # Lower-case the text once instead of once per keyword.
        lowered_text = test_str.lower()
        for word in self.banned_keywords_list:
            if not isinstance(word, str) or not word.strip():
                continue  # an empty keyword would match everything
            # Lower-case the keyword too, so "Hello" in the list still bans "hello".
            if word.lower() in lowered_text:
                raise HTTPException(
                    status_code=400,
                    detail={"error": f"Keyword banned. Keyword={word}"},
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        """Check the messages of a completion request for banned keywords.

        Only ``call_type == "completion"`` requests are inspected, and only
        messages whose ``content`` is a plain string.

        Raises:
            HTTPException: status 400 if a message contains a banned keyword.
        """
        try:
            self.print_verbose("Inside Banned Keyword List Pre-Call Hook")
            if call_type == "completion" and "messages" in data:
                # NOTE(review): non-string content (e.g. a list of content
                # parts) is skipped here - confirm whether it should be checked.
                for m in data["messages"]:
                    if "content" in m and isinstance(m["content"], str):
                        self.test_violation(test_str=m["content"])
        except HTTPException:
            raise
        except Exception:
            # Best effort: an unexpected error while scanning must not block
            # the request, so it is only reported.
            traceback.print_exc()

    async def async_post_call_success_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """Check the text of a completed (non-streaming) response.

        Every choice of a ``litellm.ModelResponse`` is inspected; choices
        without string content are skipped.

        Raises:
            HTTPException: status 400 if a choice contains a banned keyword.
        """
        if not isinstance(response, litellm.ModelResponse):
            return

        for choice in response.choices:
            if not isinstance(choice, litellm.utils.Choices):
                continue
            content = choice.message.content
            # content can be None, in which case there is no text to scan.
            if isinstance(content, str):
                self.test_violation(test_str=content)

    async def async_post_call_streaming_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: str,
    ):
        """Check a streamed response string for banned keywords.

        Raises:
            HTTPException: status 400 if the text contains a banned keyword.
        """
        self.test_violation(test_str=response)
Optional[str] = None google_moderation_confidence_threshold: Optional[float] = None llamaguard_unsafe_content_categories: Optional[str] = None blocked_user_list: Optional[Union[str, List]] = None +banned_keywords_list: Optional[Union[str, List]] = None ################## logging: bool = True caching: bool = ( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 541e1af00..030af777a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1489,6 +1489,16 @@ class ProxyConfig: blocked_user_list = _ENTERPRISE_BlockedUserList() imported_list.append(blocked_user_list) + elif ( + isinstance(callback, str) + and callback == "banned_keywords" + ): + from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import ( + _ENTERPRISE_BannedKeywords, + ) + + banned_keywords_obj = _ENTERPRISE_BannedKeywords() + imported_list.append(banned_keywords_obj) else: imported_list.append( get_instance_fn( diff --git a/litellm/tests/test_banned_keyword_list.py b/litellm/tests/test_banned_keyword_list.py new file mode 100644 index 000000000..f8804df9a --- /dev/null +++ b/litellm/tests/test_banned_keyword_list.py @@ -0,0 +1,63 @@ +# What is this? 
@pytest.mark.asyncio
async def test_banned_keywords_check():
    """
    Tests the banned keywords pre call hook for the proxy server.

    - Set some banned keywords as a litellm module value
    - Test to see if a call with banned keywords is made, an error is raised
    - Test to see if a call without banned keywords is made it passes
    """
    # Restore the module-level setting afterwards so other tests are unaffected.
    previous_banned_keywords_list = litellm.banned_keywords_list
    litellm.banned_keywords_list = ["hello"]
    try:
        banned_keywords_obj = _ENTERPRISE_BannedKeywords()

        _api_key = "sk-12345"
        user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
        local_cache = DualCache()

        ## Case 1: message containing a banned keyword -> call is rejected
        with pytest.raises(Exception) as exc_info:
            await banned_keywords_obj.async_pre_call_hook(
                user_api_key_dict=user_api_key_dict,
                cache=local_cache,
                call_type="completion",
                data={"messages": [{"role": "user", "content": "Hello world"}]},
            )
        # The hook rejects with status 400; checking it keeps an unrelated
        # setup error from being counted as a pass.
        assert getattr(exc_info.value, "status_code", None) == 400

        ## Case 2: message without a banned keyword -> call goes through
        # Any exception raised here fails the test with its full traceback.
        await banned_keywords_obj.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            call_type="completion",
            data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]},
        )
    finally:
        litellm.banned_keywords_list = previous_banned_keywords_list