fix(proxy/utils.py): move to batch writing db updates

This commit is contained in:
Krrish Dholakia 2024-03-16 22:32:00 -07:00
parent 710efab0de
commit 077b9c6234
4 changed files with 269 additions and 95 deletions

View file

@ -14,7 +14,8 @@ litellm_settings:
cache_params: cache_params:
type: redis type: redis
callbacks: ["batch_redis_requests"] callbacks: ["batch_redis_requests"]
# success_callbacks: ["langfuse"]
general_settings: general_settings:
master_key: sk-1234 master_key: sk-1234
# database_url: "postgresql://krrishdholakia:9yQkKWiB8vVs@ep-icy-union-a5j4dwls.us-east-2.aws.neon.tech/neondb?sslmode=require" database_url: "postgresql://neondb_owner:hz8tyUlJ5ivV@ep-cool-sunset-a5ywubeh.us-east-2.aws.neon.tech/neondb?sslmode=require"

View file

@ -96,6 +96,8 @@ from litellm.proxy.utils import (
_is_user_proxy_admin, _is_user_proxy_admin,
_is_projected_spend_over_limit, _is_projected_spend_over_limit,
_get_projected_spend_over_limit, _get_projected_spend_over_limit,
update_spend,
monitor_spend_list,
) )
from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
@ -1102,108 +1104,112 @@ async def update_database(
try: try:
if prisma_client is not None: # update if prisma_client is not None: # update
user_ids = [user_id, litellm_proxy_budget_name] user_ids = [user_id, litellm_proxy_budget_name]
### KEY CHANGE ###
for _id in user_ids:
prisma_client.user_list_transactons.append((_id, response_cost))
######
## do a group update for the user-id of the key + global proxy budget ## do a group update for the user-id of the key + global proxy budget
await prisma_client.db.litellm_usertable.update_many( # await prisma_client.db.litellm_usertable.update_many(
where={"user_id": {"in": user_ids}}, # where={"user_id": {"in": user_ids}},
data={"spend": {"increment": response_cost}}, # data={"spend": {"increment": response_cost}},
) # )
if end_user_id is not None: # if end_user_id is not None:
if existing_user_obj is None: # if existing_user_obj is None:
# if user does not exist in LiteLLM_UserTable, create a new user # # if user does not exist in LiteLLM_UserTable, create a new user
existing_spend = 0 # existing_spend = 0
max_user_budget = None # max_user_budget = None
if litellm.max_user_budget is not None: # if litellm.max_user_budget is not None:
max_user_budget = litellm.max_user_budget # max_user_budget = litellm.max_user_budget
existing_user_obj = LiteLLM_UserTable( # existing_user_obj = LiteLLM_UserTable(
user_id=end_user_id, # user_id=end_user_id,
spend=0, # spend=0,
max_budget=max_user_budget, # max_budget=max_user_budget,
user_email=None, # user_email=None,
) # )
else: # else:
existing_user_obj.spend = ( # existing_user_obj.spend = (
existing_user_obj.spend + response_cost # existing_user_obj.spend + response_cost
) # )
user_object_json = {**existing_user_obj.json(exclude_none=True)} # user_object_json = {**existing_user_obj.json(exclude_none=True)}
user_object_json["model_max_budget"] = json.dumps( # user_object_json["model_max_budget"] = json.dumps(
user_object_json["model_max_budget"] # user_object_json["model_max_budget"]
) # )
user_object_json["model_spend"] = json.dumps( # user_object_json["model_spend"] = json.dumps(
user_object_json["model_spend"] # user_object_json["model_spend"]
) # )
await prisma_client.db.litellm_usertable.upsert( # await prisma_client.db.litellm_usertable.upsert(
where={"user_id": end_user_id}, # where={"user_id": end_user_id},
data={ # data={
"create": user_object_json, # "create": user_object_json,
"update": {"spend": {"increment": response_cost}}, # "update": {"spend": {"increment": response_cost}},
}, # },
) # )
elif custom_db_client is not None: # elif custom_db_client is not None:
for id in user_ids: # for id in user_ids:
if id is None: # if id is None:
continue # continue
if ( # if (
custom_db_client is not None # custom_db_client is not None
and id != litellm_proxy_budget_name # and id != litellm_proxy_budget_name
): # ):
existing_spend_obj = await custom_db_client.get_data( # existing_spend_obj = await custom_db_client.get_data(
key=id, table_name="user" # key=id, table_name="user"
) # )
verbose_proxy_logger.debug( # verbose_proxy_logger.debug(
f"Updating existing_spend_obj: {existing_spend_obj}" # f"Updating existing_spend_obj: {existing_spend_obj}"
) # )
if existing_spend_obj is None: # if existing_spend_obj is None:
# if user does not exist in LiteLLM_UserTable, create a new user # # if user does not exist in LiteLLM_UserTable, create a new user
existing_spend = 0 # existing_spend = 0
max_user_budget = None # max_user_budget = None
if litellm.max_user_budget is not None: # if litellm.max_user_budget is not None:
max_user_budget = litellm.max_user_budget # max_user_budget = litellm.max_user_budget
existing_spend_obj = LiteLLM_UserTable( # existing_spend_obj = LiteLLM_UserTable(
user_id=id, # user_id=id,
spend=0, # spend=0,
max_budget=max_user_budget, # max_budget=max_user_budget,
user_email=None, # user_email=None,
) # )
else: # else:
existing_spend = existing_spend_obj.spend # existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost # # Calculate the new cost by adding the existing cost and response_cost
existing_spend_obj.spend = existing_spend + response_cost # existing_spend_obj.spend = existing_spend + response_cost
# track cost per model, for the given user # # track cost per model, for the given user
spend_per_model = existing_spend_obj.model_spend or {} # spend_per_model = existing_spend_obj.model_spend or {}
current_model = kwargs.get("model") # current_model = kwargs.get("model")
if current_model is not None and spend_per_model is not None: # if current_model is not None and spend_per_model is not None:
if spend_per_model.get(current_model) is None: # if spend_per_model.get(current_model) is None:
spend_per_model[current_model] = response_cost # spend_per_model[current_model] = response_cost
else: # else:
spend_per_model[current_model] += response_cost # spend_per_model[current_model] += response_cost
existing_spend_obj.model_spend = spend_per_model # existing_spend_obj.model_spend = spend_per_model
valid_token = user_api_key_cache.get_cache(key=id) # valid_token = user_api_key_cache.get_cache(key=id)
if valid_token is not None and isinstance(valid_token, dict): # if valid_token is not None and isinstance(valid_token, dict):
user_api_key_cache.set_cache( # user_api_key_cache.set_cache(
key=id, value=existing_spend_obj.json() # key=id, value=existing_spend_obj.json()
) # )
verbose_proxy_logger.debug( # verbose_proxy_logger.debug(
f"user - new cost: {existing_spend_obj.spend}, user_id: {id}" # f"user - new cost: {existing_spend_obj.spend}, user_id: {id}"
) # )
data_list.append(existing_spend_obj) # data_list.append(existing_spend_obj)
if custom_db_client is not None and user_id is not None: # if custom_db_client is not None and user_id is not None:
new_spend = data_list[0].spend # new_spend = data_list[0].spend
await custom_db_client.update_data( # await custom_db_client.update_data(
key=user_id, # key=user_id,
value={"spend": new_spend}, # value={"spend": new_spend},
table_name="user", # table_name="user",
) # )
except Exception as e: except Exception as e:
verbose_proxy_logger.info( verbose_proxy_logger.info(
"\033[91m" "\033[91m"
@ -1323,10 +1329,12 @@ async def update_database(
) )
raise e raise e
asyncio.create_task(_update_user_db()) # asyncio.create_task(_update_user_db())
asyncio.create_task(_update_key_db()) await _update_user_db()
asyncio.create_task(_update_team_db())
asyncio.create_task(_insert_spend_log_to_db()) # asyncio.create_task(_update_key_db())
# asyncio.create_task(_update_team_db())
# asyncio.create_task(_insert_spend_log_to_db())
verbose_proxy_logger.debug("Runs spend update on all tables") verbose_proxy_logger.debug("Runs spend update on all tables")
except Exception as e: except Exception as e:
@ -2632,6 +2640,10 @@ async def startup_event():
scheduler.add_job( scheduler.add_job(
reset_budget, "interval", seconds=interval, args=[prisma_client] reset_budget, "interval", seconds=interval, args=[prisma_client]
) )
# scheduler.add_job(
# monitor_spend_list, "interval", seconds=10, args=[prisma_client]
# )
scheduler.add_job(update_spend, "interval", seconds=60, args=[prisma_client])
scheduler.start() scheduler.start()

View file

@ -472,6 +472,11 @@ def on_backoff(details):
class PrismaClient: class PrismaClient:
user_list_transactons: List = []
key_list_transactons: List = []
team_list_transactons: List = []
spend_log_transactons: List = []
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print_verbose( print_verbose(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
@ -1841,6 +1846,65 @@ async def reset_budget(prisma_client: PrismaClient):
) )
async def update_spend(
    prisma_client: PrismaClient,
):
    """
    Batch write queued spend updates to the db.

    Triggered every minute.

    Only the user table is flushed so far: each `(user_id, response_cost)`
    tuple queued on `prisma_client.user_list_transactons` is applied as a
    `spend` increment on `litellm_usertable`, at most 5000 updates per db
    transaction. Entries whose user_id is "litellm-proxy-budget" are removed
    from the queue without being written.

    Planned inputs (only the user list is consumed at the moment):
        user_id_list: list,
        keys_list: list,
        team_list: list,
        spend_logs: list,

    Args:
        prisma_client: client exposing the db connection and the queued
            user transactions.

    Raises:
        httpx.ReadTimeout: if writing still times out after 3 retries.
        Exception: any other error raised while writing is propagated as-is.
    """
    n_retry_times = 3
    max_batch_size = 5000
    ### UPDATE USER TABLE ###
    for i in range(n_retry_times + 1):
        try:
            # Work directly off the live queue (no snapshot): entries appended
            # while a batch is being written stay queued and get picked up.
            while prisma_client.user_list_transactons:
                batch_transactions = prisma_client.user_list_transactons[
                    :max_batch_size
                ]
                async with prisma_client.db.tx(timeout=60000) as transaction:
                    async with transaction.batch_() as batcher:
                        for user_id, response_cost in batch_transactions:
                            if user_id != "litellm-proxy-budget":
                                batcher.litellm_usertable.update(
                                    where={"user_id": user_id},
                                    data={"spend": {"increment": response_cost}},
                                )
                # Remove only what was just committed. Doing this per batch
                # (instead of clearing everything at the end) means a retry
                # never re-applies increments that already reached the db.
                del prisma_client.user_list_transactons[: len(batch_transactions)]
            break  # queue drained - no need to go through the remaining retries
        except httpx.ReadTimeout:
            if i >= n_retry_times:  # If we've reached the maximum number of retries
                raise  # Re-raise the last exception
            # Sleep for a bit before retrying
            await asyncio.sleep(2**i)  # Exponential backoff

    ### UPDATE KEY TABLE ###

    ### UPDATE TEAM TABLE ###

    ### UPDATE SPEND LOGS TABLE ###
async def monitor_spend_list(prisma_client: PrismaClient):
    """
    Flush the queued user spend updates early when the queue has grown large.

    If more than 10000 entries are waiting on
    `prisma_client.user_list_transactons`, write them to the db via
    `update_spend`; otherwise do nothing.
    """
    queued_updates = len(prisma_client.user_list_transactons)
    if queued_updates <= 10000:
        return
    await update_spend(prisma_client=prisma_client)
async def _read_request_body(request): async def _read_request_body(request):
""" """
Asynchronous function to read the request body and parse it as JSON or literal data. Asynchronous function to read the request body and parse it as JSON or literal data.

View file

@ -0,0 +1,97 @@
# What is this?
## This tests the batch update spend logic on the proxy server
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
from fastapi import Request
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
import pytest, logging, asyncio
import litellm, asyncio
from litellm.proxy.proxy_server import (
new_user,
generate_key_fn,
user_api_key_auth,
user_update,
delete_key_fn,
info_key_fn,
update_key_fn,
generate_key_fn,
generate_key_helper_fn,
spend_user_fn,
spend_key_fn,
view_spend_logs,
user_info,
block_user,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
from litellm._logging import verbose_proxy_logger
verbose_proxy_logger.setLevel(level=logging.DEBUG)
from litellm.proxy._types import (
NewUserRequest,
GenerateKeyRequest,
DynamoDBArgs,
KeyRequest,
UpdateKeyRequest,
GenerateKeyRequest,
BlockUsers,
)
from litellm.proxy.utils import DBClient
from starlette.datastructures import URL
from litellm.caching import DualCache
# Module-level ProxyLogging instance, built on a fresh DualCache; the
# `prisma_client` fixture passes it to PrismaClient.
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
@pytest.fixture
def prisma_client():
    """
    Build a PrismaClient for the test from the DATABASE_URL env var.

    Connection pool settings are appended to the url (and written back to
    the environment), and a few proxy_server module globals are reset before
    the client is returned.
    """
    from litellm.proxy.proxy_cli import append_query_params

    ### add connection pool + pool timeout args
    pool_settings = {"connection_limit": 100, "pool_timeout": 60}
    os.environ["DATABASE_URL"] = append_query_params(
        os.getenv("DATABASE_URL"), pool_settings
    )

    client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )

    # Reset proxy_server globals: no custom db client, a unique proxy budget
    # name for this run, and no custom key generation hook.
    proxy_server = litellm.proxy.proxy_server
    proxy_server.custom_db_client = None
    proxy_server.litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}"
    proxy_server.user_custom_key_generate = None

    return client
@pytest.mark.asyncio
async def test_batch_update_spend(prisma_client):
    """
    Queue one user spend update, flush it with `update_spend`, and check that
    the queue is empty afterwards.
    """
    prisma_client.user_list_transactons.append(("test-litellm-user-5", 23))
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    await update_spend(prisma_client=litellm.proxy.proxy_server.prisma_client)
    # A successful flush leaves nothing queued. This replaces the unconditional
    # `raise Exception(...)` debug marker, which made the test fail every time.
    assert len(prisma_client.user_list_transactons) == 0