# What is this?
## This tests the blocked user pre call hook for the proxy server


import asyncio
import os
import random
import sys
import time
import traceback
from datetime import datetime

from dotenv import load_dotenv
from fastapi import Request

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import logging

import pytest

import litellm
from litellm import Router, mock_completion
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
    _ENTERPRISE_BlockedUserList,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import (
    new_user,
    user_info,
    user_update,
)
from litellm.proxy.management_endpoints.key_management_endpoints import (
    delete_key_fn,
    generate_key_fn,
    generate_key_helper_fn,
    info_key_fn,
    update_key_fn,
)
from litellm.proxy.proxy_server import user_api_key_auth
from litellm.proxy.management_endpoints.customer_endpoints import block_user
from litellm.proxy.spend_tracking.spend_management_endpoints import (
    spend_key_fn,
    spend_user_fn,
    view_spend_logs,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token

verbose_proxy_logger.setLevel(level=logging.DEBUG)

from starlette.datastructures import URL

from litellm.caching.caching import DualCache
from litellm.proxy._types import (
    BlockUsers,
    DynamoDBArgs,
    GenerateKeyRequest,
    KeyRequest,
    NewUserRequest,
    UpdateKeyRequest,
)

# Shared by every PrismaClient built in this module.
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())


@pytest.fixture
def prisma_client():
    """Return a PrismaClient built from the DATABASE_URL environment variable.

    Side effects (all visible below):
    - rewrites ``os.environ["DATABASE_URL"]`` with connection-pool query
      parameters appended,
    - assigns a time-stamped ``litellm_proxy_budget_name`` on the proxy
      server module,
    - clears ``user_custom_key_generate`` on the proxy server module.

    The client is returned without ``connect()`` being called here.
    """
    from litellm.proxy.proxy_cli import append_query_params

    ### add connection pool + pool timeout args
    params = {"connection_limit": 100, "pool_timeout": 60}
    database_url = os.getenv("DATABASE_URL")
    modified_url = append_query_params(database_url, params)
    os.environ["DATABASE_URL"] = modified_url

    client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )

    # Time-stamped budget name plus a cleared custom key generator on the
    # proxy server module (presumably to isolate tests from each other).
    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
        f"litellm-proxy-budget-{time.time()}"
    )
    litellm.proxy.proxy_server.user_custom_key_generate = None

    return client


@pytest.mark.asyncio
async def test_block_user_check(prisma_client):
    """
    - Set a blocked user as a litellm module value
    - Test that a call made with that user id raises an error
    - Test that a call made with a different user id passes
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    litellm.blocked_user_list = ["user_id_1"]

    blocked_user_obj = _ENTERPRISE_BlockedUserList(
        prisma_client=litellm.proxy.proxy_server.prisma_client
    )

    _api_key = hash_token("sk-12345")
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
    local_cache = DualCache()

    ## Case 1: blocked user id passed
    # pytest.raises fails the test on its own if the hook does not raise, so
    # no pytest.fail() call has to survive an `except Exception` handler.
    with pytest.raises(Exception):
        await blocked_user_obj.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            call_type="completion",
            data={"user_id": "user_id_1"},
        )

    ## Case 2: normal user id passed
    try:
        await blocked_user_obj.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            call_type="completion",
            data={"user_id": "user_id_2"},
        )
    except Exception as e:
        pytest.fail(f"An error occurred - {str(e)}")


@pytest.mark.asyncio
async def test_block_user_db_check(prisma_client):
    """
    - Block end user via "/user/block"
    - Check returned value
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    _block_users = BlockUsers(user_ids=["user_id_1"])
    response = await block_user(data=_block_users)
    blocked_users = response["blocked_users"]

    # Exactly one record should come back, flagged as blocked.
    assert len(blocked_users) == 1
    assert blocked_users[0].user_id == "user_id_1"
    assert blocked_users[0].blocked is True