forked from phoenix/litellm-mirror
feat(proxy_server.py): add support for blocked user lists (enterprise-only)
This commit is contained in:
parent
5be70b94e7
commit
028f455ad0
4 changed files with 154 additions and 0 deletions
80
enterprise/enterprise_hooks/blocked_user_list.py
Normal file
80
enterprise/enterprise_hooks/blocked_user_list.py
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
# +------------------------------+
|
||||||
|
#
|
||||||
|
# Blocked User List
|
||||||
|
#
|
||||||
|
# +------------------------------+
|
||||||
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
## This accepts a list of user id's for whom calls will be rejected
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional, Literal
|
||||||
|
import litellm
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import json, traceback
|
||||||
|
|
||||||
|
|
||||||
|
class _ENTERPRISE_BlockedUserList(CustomLogger):
|
||||||
|
# Class variables or attributes
|
||||||
|
def __init__(self):
|
||||||
|
blocked_user_list = litellm.blocked_user_list
|
||||||
|
|
||||||
|
if blocked_user_list is None:
|
||||||
|
raise Exception(
|
||||||
|
"`blocked_user_list` can either be a list or filepath. None set."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(blocked_user_list, list):
|
||||||
|
self.blocked_user_list = blocked_user_list
|
||||||
|
|
||||||
|
if isinstance(blocked_user_list, str): # assume it's a filepath
|
||||||
|
try:
|
||||||
|
with open(blocked_user_list, "r") as file:
|
||||||
|
data = file.read()
|
||||||
|
self.blocked_user_list = data.split("\n")
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise Exception(
|
||||||
|
f"File not found. blocked_user_list={blocked_user_list}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"An error occurred: {str(e)}, blocked_user_list={blocked_user_list}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
|
||||||
|
if level == "INFO":
|
||||||
|
verbose_proxy_logger.info(print_statement)
|
||||||
|
elif level == "DEBUG":
|
||||||
|
verbose_proxy_logger.debug(print_statement)
|
||||||
|
|
||||||
|
if litellm.set_verbose is True:
|
||||||
|
print(print_statement) # noqa
|
||||||
|
|
||||||
|
async def async_pre_call_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
cache: DualCache,
|
||||||
|
data: dict,
|
||||||
|
call_type: str,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
"""
|
||||||
|
- check if user id part of call
|
||||||
|
- check if user id part of blocked list
|
||||||
|
"""
|
||||||
|
self.print_verbose(f"Inside Blocked User List Pre-Call Hook")
|
||||||
|
if "user_id" in data:
|
||||||
|
if data["user_id"] in self.blocked_user_list:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": f"User blocked from making LLM API Calls. User={data['user_id']}"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except HTTPException as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
|
@ -60,6 +60,7 @@ llamaguard_model_name: Optional[str] = None
|
||||||
presidio_ad_hoc_recognizers: Optional[str] = None
|
presidio_ad_hoc_recognizers: Optional[str] = None
|
||||||
google_moderation_confidence_threshold: Optional[float] = None
|
google_moderation_confidence_threshold: Optional[float] = None
|
||||||
llamaguard_unsafe_content_categories: Optional[str] = None
|
llamaguard_unsafe_content_categories: Optional[str] = None
|
||||||
|
blocked_user_list: Optional[Union[str, List]] = None
|
||||||
##################
|
##################
|
||||||
logging: bool = True
|
logging: bool = True
|
||||||
caching: bool = (
|
caching: bool = (
|
||||||
|
|
|
@ -1479,6 +1479,16 @@ class ProxyConfig:
|
||||||
|
|
||||||
llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
|
llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
|
||||||
imported_list.append(llm_guard_moderation_obj)
|
imported_list.append(llm_guard_moderation_obj)
|
||||||
|
elif (
|
||||||
|
isinstance(callback, str)
|
||||||
|
and callback == "blocked_user_check"
|
||||||
|
):
|
||||||
|
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
|
||||||
|
_ENTERPRISE_BlockedUserList,
|
||||||
|
)
|
||||||
|
|
||||||
|
blocked_user_list = _ENTERPRISE_BlockedUserList()
|
||||||
|
imported_list.append(blocked_user_list)
|
||||||
else:
|
else:
|
||||||
imported_list.append(
|
imported_list.append(
|
||||||
get_instance_fn(
|
get_instance_fn(
|
||||||
|
|
63
litellm/tests/test_blocked_user_list.py
Normal file
63
litellm/tests/test_blocked_user_list.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
# What is this?
|
||||||
|
## This tests the blocked user pre call hook for the proxy server
|
||||||
|
|
||||||
|
|
||||||
|
import sys, os, asyncio, time, random
|
||||||
|
from datetime import datetime
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
|
||||||
|
_ENTERPRISE_BlockedUserList,
|
||||||
|
)
|
||||||
|
from litellm import Router, mock_completion
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_block_user_check():
|
||||||
|
"""
|
||||||
|
- Set a blocked user as a litellm module value
|
||||||
|
- Test to see if a call with that user id is made, an error is raised
|
||||||
|
- Test to see if a call without that user is passes
|
||||||
|
"""
|
||||||
|
litellm.blocked_user_list = ["user_id_1"]
|
||||||
|
|
||||||
|
blocked_user_obj = _ENTERPRISE_BlockedUserList()
|
||||||
|
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||||
|
local_cache = DualCache()
|
||||||
|
|
||||||
|
## Case 1: blocked user id passed
|
||||||
|
try:
|
||||||
|
await blocked_user_obj.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
call_type="completion",
|
||||||
|
data={"user_id": "user_id_1"},
|
||||||
|
)
|
||||||
|
pytest.fail(f"Expected call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
## Case 2: normal user id passed
|
||||||
|
try:
|
||||||
|
await blocked_user_obj.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
call_type="completion",
|
||||||
|
data={"user_id": "user_id_2"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An error occurred - {str(e)}")
|
Loading…
Add table
Add a link
Reference in a new issue