diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 05f9bdbcf4..e95a2ebba0 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1948,12 +1948,6 @@ async def startup_event(): 597, 605 ) # random interval, so multiple workers avoid resetting budget at the same time scheduler.add_job(reset_budget, "interval", seconds=interval, args=[prisma_client]) - scheduler.add_job( - failed_transaction_writer, - "interval", - seconds=10, - args=[prisma_client, user_api_key_cache], - ) scheduler.start() diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 5c827f1748..163dc5e2cf 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1274,6 +1274,48 @@ async def failed_transaction_writer( - Batch writes them to the table - Sets value in cache to None """ + ### UPDATE USER DB ### + existing_list = await user_api_key_cache.async_get_cache( + key="Failed_USER_DB_Transactions" + ) + if existing_list is not None and isinstance(existing_list, list): + try: + async with prisma_client.db.tx( + max_wait=timedelta(seconds=10), timeout=timedelta(seconds=30) + ) as transaction: + async with transaction.batch_() as batcher: + for item in existing_list: # [..., (response_cost, user_id)] + batcher.litellm_usertable.update( + where={"user_id": item[1]}, + data={"spend": {"increment": item[0]}}, + ) + await user_api_key_cache.async_set_cache( + key="Failed_USER_DB_Transactions", value=None + ) + except Exception as e: + pass + + ### UPDATE KEYS DB ### + existing_list = await user_api_key_cache.async_get_cache( + key="Failed_Keys_DB_Transactions" + ) + if existing_list is not None and isinstance(existing_list, list): + try: + async with prisma_client.db.tx( + max_wait=timedelta(seconds=10), timeout=timedelta(seconds=30) + ) as transaction: + async with transaction.batch_() as batcher: + for item in existing_list: # [..., (response_cost, token)] + batcher.litellm_verificationtoken.update( + where={"token": item[1]}, + data={"spend": {"increment": item[0]}}, + ) + await user_api_key_cache.async_set_cache( + key="Failed_Keys_DB_Transactions", value=None + ) + except Exception as e: + pass + ### UPDATE SPEND LOGS DB ### existing_list = await user_api_key_cache.async_get_cache( key="Failed_Spend_Logs_DB_Transactions" diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 4180652b3d..bd96b50ded 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -83,7 +83,7 @@ def prisma_client(): from litellm.proxy.proxy_cli import append_query_params ### add connection pool + pool timeout args - params = {"connection_limit": 500, "pool_timeout": 60} + params = {"connection_limit": 100, "pool_timeout": 60} database_url = os.getenv("DATABASE_URL") modified_url = append_query_params(database_url, params) os.environ["DATABASE_URL"] = modified_url @@ -1611,13 +1611,13 @@ async def test_proxy_load_test_db(prisma_client): result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - n = 1000 + n = 5000 tasks = [ track_cost_callback_helper_fn(generated_key=generated_key, user_id=user_id) for _ in range(n) ] completions = await asyncio.gather(*tasks) - await asyncio.sleep(10) + await asyncio.sleep(120) try: # call spend logs spend_logs = await view_spend_logs(api_key=generated_key)