litellm/tests/local_testing/test_blocked_user_list.py
Ishaan Jaff 51ffe93e77
(docs) add docstrings for all /key, /user, /team, /customer endpoints (#6804)
* use helper to handle_exception_on_proxy

* add doc string for /key/regenerate

* use 1 helper for handle_exception_on_proxy

* add doc string for /key/block

* add doc string for /key/unblock

* remove deprecated function

* remove deprecated endpoints

* remove incorrect tag for endpoint

* fix linting

* fix /key/regenerate

* fix regen key

* fix use port 4000 for user endpoints

* fix clean up - use separate file for customer endpoints

* add docstring for user/update

* fix imports

* doc string /user/list

* doc string for /team/delete

* fix team block endpoint

* fix import block user

* add doc string for /team/unblock

* add doc string for /team/list

* add doc string for /team/info

* add doc string for key endpoints

* fix customer_endpoints

* add doc string for customer endpoints

* fix import new_end_user

* fix testing

* fix import new_end_user

* fix add check for allow_user_auth
2024-11-18 19:44:06 -08:00

156 lines
4.5 KiB
Python

# What is this?
## This tests the blocked user pre call hook for the proxy server
import asyncio
import os
import random
import sys
import time
import traceback
from datetime import datetime
from dotenv import load_dotenv
from fastapi import Request
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import logging
import pytest
import litellm
from litellm import Router, mock_completion
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
_ENTERPRISE_BlockedUserList,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import (
new_user,
user_info,
user_update,
)
from litellm.proxy.management_endpoints.key_management_endpoints import (
delete_key_fn,
generate_key_fn,
generate_key_helper_fn,
info_key_fn,
update_key_fn,
)
from litellm.proxy.proxy_server import user_api_key_auth
from litellm.proxy.management_endpoints.customer_endpoints import block_user
from litellm.proxy.spend_tracking.spend_management_endpoints import (
spend_key_fn,
spend_user_fn,
view_spend_logs,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
# Emit DEBUG-level output from the proxy logger for every test in this module.
verbose_proxy_logger.setLevel(level=logging.DEBUG)
from starlette.datastructures import URL
from litellm.caching.caching import DualCache
from litellm.proxy._types import (
BlockUsers,
DynamoDBArgs,
GenerateKeyRequest,
KeyRequest,
NewUserRequest,
UpdateKeyRequest,
)
# Module-level ProxyLogging instance backed by its own DualCache; the
# prisma_client fixture passes it to PrismaClient.
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
@pytest.fixture
def prisma_client():
    """
    Build a PrismaClient for the database named by ``DATABASE_URL``.

    The URL is extended with connection-pool query parameters and written
    back to the environment before the client is created. Two proxy-server
    module globals are also touched: ``litellm_proxy_budget_name`` receives a
    timestamped value and ``user_custom_key_generate`` is cleared.

    Returns:
        The newly constructed PrismaClient (not yet connected by this fixture).
    """
    from litellm.proxy.proxy_cli import append_query_params

    ### add connection pool + pool timeout args
    pool_settings = {"connection_limit": 100, "pool_timeout": 60}
    os.environ["DATABASE_URL"] = append_query_params(
        os.getenv("DATABASE_URL"), pool_settings
    )

    client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )

    # Give the proxy a unique budget name and drop any custom key generator.
    proxy_server = litellm.proxy.proxy_server
    proxy_server.litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}"
    proxy_server.user_custom_key_generate = None

    return client
@pytest.mark.asyncio
async def test_block_user_check(prisma_client):
    """
    Check the blocked-user pre-call hook against a module-level block list.

    - Set a blocked user as a litellm module value
    - A call made with that user id must raise an error
    - A call made with a different user id must pass

    Args:
        prisma_client: client supplied by the ``prisma_client`` fixture.
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    litellm.blocked_user_list = ["user_id_1"]

    blocked_user_obj = _ENTERPRISE_BlockedUserList(
        prisma_client=litellm.proxy.proxy_server.prisma_client
    )

    # The auth object carries the hashed form of the key.
    user_api_key_dict = UserAPIKeyAuth(api_key=hash_token("sk-12345"))
    local_cache = DualCache()

    ## Case 1: blocked user id passed
    # pytest.raises fails the test by itself when nothing is raised, so the
    # check no longer relies on pytest.fail escaping an `except Exception`.
    # NOTE(review): narrow `Exception` to the hook's concrete error type once
    # confirmed against the hook implementation.
    with pytest.raises(Exception):
        await blocked_user_obj.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            call_type="completion",
            data={"user_id": "user_id_1"},
        )

    ## Case 2: normal user id passed
    try:
        await blocked_user_obj.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            call_type="completion",
            data={"user_id": "user_id_2"},
        )
    except Exception as e:
        pytest.fail(f"An error occurred - {str(e)}")
@pytest.mark.asyncio
async def test_block_user_db_check(prisma_client):
    """
    Block an end user through ``block_user`` and verify the returned record.

    - Install and connect the Prisma client on the proxy server module
    - Block "user_id_1"
    - Expect exactly one returned entry, for that user, flagged as blocked
    """
    proxy_server = litellm.proxy.proxy_server
    proxy_server.prisma_client = prisma_client
    proxy_server.master_key = "sk-1234"
    await proxy_server.prisma_client.connect()

    response = await block_user(data=BlockUsers(user_ids=["user_id_1"]))
    blocked_users = response["blocked_users"]

    assert len(blocked_users) == 1
    blocked_entry = blocked_users[0]
    assert blocked_entry.user_id == "user_id_1"
    assert blocked_entry.blocked == True