# Forked from phoenix/litellm-mirror.
#
# Last commit touching this file:
#   * use helper to handle_exception_on_proxy * add doc string for /key/regenerate
#   * use 1 helper for handle_exception_on_proxy * add doc string for /key/block
#   * add doc string for /key/unblock * remove deprecated function
#   * remove deprecated endpoints * remove incorrect tag for endpoint * fix linting
#   * fix /key/regenerate * fix regen key * fix use port 4000 for user endpoints
#   * fix clean up - use separate file for customer endpoints
#   * add docstring for user/update * fix imports * doc string /user/list
#   * doc string for /team/delete * fix team block endpoint * fix import block user
#   * add doc string for /team/unblock * add doc string for /team/list
#   * add doc string for /team/info * add doc string for key endpoints
#   * fix customer_endpoints * add doc string for customer endpoints
#   * fix import new_end_user * fix testing * fix import new_end_user
#   * fix add check for allow_user_auth
#
# File stats (from the repository viewer): 156 lines, 4.5 KiB, Python.
# What is this?
|
|
## This tests the blocked user pre call hook for the proxy server
|
|
|
|
|
|
import asyncio
|
|
import os
|
|
import random
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from datetime import datetime
|
|
|
|
from dotenv import load_dotenv
|
|
from fastapi import Request
|
|
|
|
load_dotenv()
|
|
import os
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../..")
|
|
) # Adds the parent directory to the system path
|
|
import asyncio
|
|
import logging
|
|
|
|
import pytest
|
|
|
|
import litellm
|
|
from litellm import Router, mock_completion
|
|
from litellm._logging import verbose_proxy_logger
|
|
from litellm.caching.caching import DualCache
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
|
|
_ENTERPRISE_BlockedUserList,
|
|
)
|
|
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
|
new_user,
|
|
user_info,
|
|
user_update,
|
|
)
|
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
|
delete_key_fn,
|
|
generate_key_fn,
|
|
generate_key_helper_fn,
|
|
info_key_fn,
|
|
update_key_fn,
|
|
)
|
|
from litellm.proxy.proxy_server import user_api_key_auth
|
|
from litellm.proxy.management_endpoints.customer_endpoints import block_user
|
|
from litellm.proxy.spend_tracking.spend_management_endpoints import (
|
|
spend_key_fn,
|
|
spend_user_fn,
|
|
view_spend_logs,
|
|
)
|
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
|
|
|
|
verbose_proxy_logger.setLevel(level=logging.DEBUG)
|
|
|
|
from starlette.datastructures import URL
|
|
|
|
from litellm.caching.caching import DualCache
|
|
from litellm.proxy._types import (
|
|
BlockUsers,
|
|
DynamoDBArgs,
|
|
GenerateKeyRequest,
|
|
KeyRequest,
|
|
NewUserRequest,
|
|
UpdateKeyRequest,
|
|
)
|
|
|
|
# Module-level ProxyLogging instance (backed by a fresh DualCache as its
# user-API-key cache); it is passed as `proxy_logging_obj` to every
# PrismaClient that the `prisma_client` fixture constructs.
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
|
|
|
|
|
|
@pytest.fixture
def prisma_client():
    """Build a PrismaClient pointed at ``DATABASE_URL`` for the tests in this file.

    Side effects:
      * rewrites ``os.environ["DATABASE_URL"]`` with connection-pool query
        params (``connection_limit=100``, ``pool_timeout=60``) appended;
      * sets ``litellm.proxy.proxy_server.litellm_proxy_budget_name`` to a
        timestamped value and clears ``user_custom_key_generate``.

    This fixture does not call ``connect()``; tests that need a live database
    connection do so themselves.

    Returns:
        The newly constructed ``PrismaClient``.
    """
    from litellm.proxy.proxy_cli import append_query_params

    # Append connection-pool settings to the DB URL the client will use.
    pool_params = {"connection_limit": 100, "pool_timeout": 60}
    database_url = os.getenv("DATABASE_URL")
    os.environ["DATABASE_URL"] = append_query_params(database_url, pool_params)

    # Named `client` (not `prisma_client`) to avoid shadowing this fixture.
    client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )

    # Reset proxy-server globals that other tests may have changed: use a
    # unique budget name per run and disable any custom key-generation hook.
    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
        f"litellm-proxy-budget-{time.time()}"
    )
    litellm.proxy.proxy_server.user_custom_key_generate = None

    return client
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_block_user_check(prisma_client):
    """
    - Set a blocked user as a litellm module value
    - Test that a call made with that user id raises an error
    - Test that a call made with a different user id passes
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")

    # `litellm.blocked_user_list` is module-global state; remember the prior
    # value so this test's block list does not leak into other tests.
    previous_blocked_user_list = getattr(litellm, "blocked_user_list", None)
    litellm.blocked_user_list = ["user_id_1"]
    try:
        blocked_user_obj = _ENTERPRISE_BlockedUserList(
            prisma_client=litellm.proxy.proxy_server.prisma_client
        )

        user_api_key_dict = UserAPIKeyAuth(api_key=hash_token("sk-12345"))
        local_cache = DualCache()

        ## Case 1: blocked user id passed -> the pre-call hook must raise
        with pytest.raises(Exception):
            await blocked_user_obj.async_pre_call_hook(
                user_api_key_dict=user_api_key_dict,
                cache=local_cache,
                call_type="completion",
                data={"user_id": "user_id_1"},
            )

        ## Case 2: normal user id passed -> the pre-call hook must not raise
        try:
            await blocked_user_obj.async_pre_call_hook(
                user_api_key_dict=user_api_key_dict,
                cache=local_cache,
                call_type="completion",
                data={"user_id": "user_id_2"},
            )
        except Exception as e:
            pytest.fail(f"An error occurred - {str(e)}")
    finally:
        litellm.blocked_user_list = previous_blocked_user_list
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_block_user_db_check(prisma_client):
    """
    - Block an end user via the customer-endpoints `block_user` handler
    - Check that the returned value lists exactly that user, marked as blocked
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        response = await block_user(data=BlockUsers(user_ids=["user_id_1"]))
        blocked_users = response["blocked_users"]
        assert len(blocked_users) == 1
        assert blocked_users[0].user_id == "user_id_1"
        assert blocked_users[0].blocked is True
    finally:
        # connect() above opened a DB connection for this test's fresh client;
        # close it so it does not outlive the test.
        await litellm.proxy.proxy_server.prisma_client.disconnect()
|