diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index e6f2437e7..2eb693cf4 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -2013,3 +2013,74 @@ async def test_master_key_hashing(prisma_client): except Exception as e: print("Got Exception", e) pytest.fail(f"Got exception {e}") + + +@pytest.mark.asyncio +async def test_reset_spend_authentication(prisma_client): + """ + 1. Test master key can access this route -> ONLY MASTER KEY SHOULD BE ABLE TO RESET SPEND + 2. Test that non-master key gets rejected + 3. Test that non-master key with role == "proxy_admin" or admin gets rejected + """ + + print("prisma client=", prisma_client) + + master_key = "sk-1234" + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", master_key) + + await litellm.proxy.proxy_server.prisma_client.connect() + from litellm.proxy.proxy_server import user_api_key_cache + + bearer_token = "Bearer " + master_key + + request = Request(scope={"type": "http"}) + request._url = URL(url="/global/spend/reset") + + # Test 1 - Master Key + result: UserAPIKeyAuth = await user_api_key_auth( + request=request, api_key=bearer_token + ) + + print("result from user auth with Master key", result) + assert result.token is not None + + # Test 2 - Non-Master Key + _response = await new_user( + data=NewUserRequest( + tpm_limit=20, + ) + ) + + generate_key = "Bearer " + _response.key + + try: + await user_api_key_auth(request=request, api_key=generate_key) + pytest.fail(f"This should have failed!. IT's an expired key") + except Exception as e: + print("Got Exception", e) + assert ( + "Tried to access route=/global/spend/reset, which is only for MASTER KEY" + in e.message + ) + + # Test 3 - Non-Master Key with role == "proxy_admin" or admin + _response = await new_user( + data=NewUserRequest( + user_role="proxy_admin", + tpm_limit=20, + ) + ) + + generate_key = "Bearer " + _response.key + + try: + await user_api_key_auth(request=request, api_key=generate_key) + pytest.fail(f"This should have failed!. IT's an expired key") + except Exception as e: + print("Got Exception", e) + assert ( + "Tried to access route=/global/spend/reset, which is only for MASTER KEY" + in e.message + )