diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 0d1f488e1..cdd7c8683 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -76,6 +76,7 @@ from litellm.proxy.proxy_server import ( user_api_key_auth, ) from litellm.proxy.spend_tracking.spend_management_endpoints import ( + global_spend, spend_key_fn, spend_user_fn, view_spend_logs, @@ -99,6 +100,7 @@ from litellm.proxy._types import ( ProxyException, UpdateKeyRequest, UpdateTeamRequest, + UpdateUserRequest, UserAPIKeyAuth, ) from litellm.proxy.utils import DBClient @@ -2488,3 +2490,58 @@ async def test_enforced_params(prisma_client): in e.message ) general_settings.pop("enforced_params") + + +@pytest.mark.asyncio() +async def test_update_user_role(prisma_client): + """ + Tests if we update user role, incorrect values are not stored in cache + -> create a user with role == INTERNAL_USER + -> access an Admin only route -> expect to fail + + -> update user role to == PROXY_ADMIN + -> access an Admin only route -> expect to succeed + """ + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + key = await new_user( + data=NewUserRequest( + user_role=LitellmUserRoles.INTERNAL_USER, + ) + ) + + print(key) + api_key = "Bearer " + key.key + + api_route = APIRoute(path="/global/spend", endpoint=global_spend) + request = Request( + { + "type": "http", + "route": api_route, + "path": "/global/spend", + "headers": [("Authorization", api_key)], + } + ) + + request._url = URL(url="/global/spend") + + # use generated key to auth in + try: + result = await user_api_key_auth(request=request, api_key=api_key) + print("result from user auth with new key", result) + except Exception as e: + print(e) + pass + + await user_update( + data=UpdateUserRequest( + user_id=key.user_id, user_role=LitellmUserRoles.PROXY_ADMIN + ) + ) + + await asyncio.sleep(2) + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=api_key) + print("result from user auth with new key", result)