From 028f455ad040ddc2f33dcad838a7e968c3d55192 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 22 Feb 2024 17:51:31 -0800 Subject: [PATCH] feat(proxy_server.py): add support for blocked user lists (enterprise-only) --- .../enterprise_hooks/blocked_user_list.py | 80 +++++++++++++++++++ litellm/__init__.py | 1 + litellm/proxy/proxy_server.py | 10 +++ litellm/tests/test_blocked_user_list.py | 63 +++++++++++++++ 4 files changed, 154 insertions(+) create mode 100644 enterprise/enterprise_hooks/blocked_user_list.py create mode 100644 litellm/tests/test_blocked_user_list.py diff --git a/enterprise/enterprise_hooks/blocked_user_list.py b/enterprise/enterprise_hooks/blocked_user_list.py new file mode 100644 index 000000000..26a1bd9f7 --- /dev/null +++ b/enterprise/enterprise_hooks/blocked_user_list.py @@ -0,0 +1,80 @@ +# +------------------------------+ +# +# Blocked User List +# +# +------------------------------+ +# Thank you users! We ❤️ you! - Krrish & Ishaan +## This accepts a list of user id's for whom calls will be rejected + + +from typing import Optional, Literal +import litellm +from litellm.caching import DualCache +from litellm.proxy._types import UserAPIKeyAuth +from litellm.integrations.custom_logger import CustomLogger +from litellm._logging import verbose_proxy_logger +from fastapi import HTTPException +import json, traceback + + +class _ENTERPRISE_BlockedUserList(CustomLogger): + # Class variables or attributes + def __init__(self): + blocked_user_list = litellm.blocked_user_list + + if blocked_user_list is None: + raise Exception( + "`blocked_user_list` can either be a list or filepath. None set." + ) + + if isinstance(blocked_user_list, list): + self.blocked_user_list = blocked_user_list + + if isinstance(blocked_user_list, str): # assume it's a filepath + try: + with open(blocked_user_list, "r") as file: + data = file.read() + self.blocked_user_list = data.split("\n") + except FileNotFoundError: + raise Exception( + f"File not found. blocked_user_list={blocked_user_list}" + ) + except Exception as e: + raise Exception( + f"An error occurred: {str(e)}, blocked_user_list={blocked_user_list}" + ) + + def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"): + if level == "INFO": + verbose_proxy_logger.info(print_statement) + elif level == "DEBUG": + verbose_proxy_logger.debug(print_statement) + + if litellm.set_verbose is True: + print(print_statement) # noqa + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + ): + try: + """ + - check if user id part of call + - check if user id part of blocked list + """ + self.print_verbose(f"Inside Blocked User List Pre-Call Hook") + if "user_id" in data: + if data["user_id"] in self.blocked_user_list: + raise HTTPException( + status_code=400, + detail={ + "error": f"User blocked from making LLM API Calls. User={data['user_id']}" + }, + ) + except HTTPException as e: + raise e + except Exception as e: + traceback.print_exc() diff --git a/litellm/__init__.py b/litellm/__init__.py index 83bd98c46..9b3107b2d 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -60,6 +60,7 @@ llamaguard_model_name: Optional[str] = None presidio_ad_hoc_recognizers: Optional[str] = None google_moderation_confidence_threshold: Optional[float] = None llamaguard_unsafe_content_categories: Optional[str] = None +blocked_user_list: Optional[Union[str, List]] = None ################## logging: bool = True caching: bool = ( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7a1632fcd..541e1af00 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1479,6 +1479,16 @@ class ProxyConfig: llm_guard_moderation_obj = _ENTERPRISE_LLMGuard() imported_list.append(llm_guard_moderation_obj) + elif ( + isinstance(callback, str) + and callback == "blocked_user_check" + ): + from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import ( + _ENTERPRISE_BlockedUserList, + ) + + blocked_user_list = _ENTERPRISE_BlockedUserList() + imported_list.append(blocked_user_list) else: imported_list.append( get_instance_fn( diff --git a/litellm/tests/test_blocked_user_list.py b/litellm/tests/test_blocked_user_list.py new file mode 100644 index 000000000..b40d8296c --- /dev/null +++ b/litellm/tests/test_blocked_user_list.py @@ -0,0 +1,63 @@ +# What is this? +## This tests the blocked user pre call hook for the proxy server + + +import sys, os, asyncio, time, random +from datetime import datetime +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +import litellm +from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import ( + _ENTERPRISE_BlockedUserList, +) +from litellm import Router, mock_completion +from litellm.proxy.utils import ProxyLogging +from litellm.proxy._types import UserAPIKeyAuth +from litellm.caching import DualCache + + +@pytest.mark.asyncio +async def test_block_user_check(): + """ + - Set a blocked user as a litellm module value + - Test to see if a call with that user id is made, an error is raised + - Test to see if a call without that user is passes + """ + litellm.blocked_user_list = ["user_id_1"] + + blocked_user_obj = _ENTERPRISE_BlockedUserList() + + _api_key = "sk-12345" + user_api_key_dict = UserAPIKeyAuth(api_key=_api_key) + local_cache = DualCache() + + ## Case 1: blocked user id passed + try: + await blocked_user_obj.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + call_type="completion", + data={"user_id": "user_id_1"}, + ) + pytest.fail(f"Expected call to fail") + except Exception as e: + pass + + ## Case 2: normal user id passed + try: + await blocked_user_obj.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + call_type="completion", + data={"user_id": "user_id_2"}, + ) + except Exception as e: + pytest.fail(f"An error occurred - {str(e)}")