diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 78f37ad34..588846bd3 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -996,10 +996,8 @@ async def _PROXY_track_cost_callback( ) litellm_params = kwargs.get("litellm_params", {}) or {} proxy_server_request = litellm_params.get("proxy_server_request") or {} - user_id = proxy_server_request.get("body", {}).get("user", None) - user_id = user_id or kwargs["litellm_params"]["metadata"].get( - "user_api_key_user_id", None - ) + end_user_id = proxy_server_request.get("body", {}).get("user", None) + user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None) team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None) if kwargs.get("response_cost", None) is not None: response_cost = kwargs["response_cost"] @@ -1013,9 +1011,6 @@ async def _PROXY_track_cost_callback( f"Cache Hit: response_cost {response_cost}, for user_id {user_id}" ) - verbose_proxy_logger.info( - f"response_cost {response_cost}, for user_id {user_id}" - ) verbose_proxy_logger.debug( f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}" ) @@ -1025,6 +1020,7 @@ async def _PROXY_track_cost_callback( token=user_api_key, response_cost=response_cost, user_id=user_id, + end_user_id=end_user_id, team_id=team_id, kwargs=kwargs, completion_response=completion_response, @@ -1066,6 +1062,7 @@ async def update_database( token, response_cost, user_id=None, + end_user_id=None, team_id=None, kwargs=None, completion_response=None, @@ -1076,6 +1073,10 @@ async def update_database( verbose_proxy_logger.info( f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}" ) + if isinstance(token, str) and token.startswith("sk-"): + hashed_token = hash_token(token=token) + else: + hashed_token = token ### UPDATE USER SPEND ### async def _update_user_db(): @@ -1084,11 +1085,6 @@ async def update_database( - Update litellm-proxy-budget row (global proxy spend) """ ## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db - end_user_id = None - if isinstance(token, str) and token.startswith("sk-"): - hashed_token = hash_token(token=token) - else: - hashed_token = token existing_token_obj = await user_api_key_cache.async_get_cache( key=hashed_token ) @@ -1097,119 +1093,24 @@ async def update_database( existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if existing_user_obj is not None and isinstance(existing_user_obj, dict): existing_user_obj = LiteLLM_UserTable(**existing_user_obj) - if existing_token_obj.user_id != user_id: # an end-user id was passed in - end_user_id = user_id - user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name] - data_list = [] try: if prisma_client is not None: # update - user_ids = [user_id, litellm_proxy_budget_name] + user_ids = [user_id] + if ( + litellm.max_budget > 0 + ): # track global proxy budget, if user set max budget + user_ids.append(litellm_proxy_budget_name) ### KEY CHANGE ### for _id in user_ids: - prisma_client.user_list_transactons.append((_id, response_cost)) - ###### - ## do a group update for the user-id of the key + global proxy budget - # await prisma_client.db.litellm_usertable.update_many( - # where={"user_id": {"in": user_ids}}, - # data={"spend": {"increment": response_cost}}, - # ) - # if end_user_id is not None: - # if existing_user_obj is None: - # # if user does not exist in LiteLLM_UserTable, create a new user - # existing_spend = 0 - # max_user_budget = None - # if litellm.max_user_budget is not None: - # max_user_budget = litellm.max_user_budget - # existing_user_obj = LiteLLM_UserTable( - # user_id=end_user_id, - # spend=0, - # max_budget=max_user_budget, - # user_email=None, - # ) - - # else: - # existing_user_obj.spend = ( - # existing_user_obj.spend + response_cost - # ) - - # user_object_json = {**existing_user_obj.json(exclude_none=True)} - - # user_object_json["model_max_budget"] = json.dumps( - # user_object_json["model_max_budget"] - # ) - # user_object_json["model_spend"] = json.dumps( - # user_object_json["model_spend"] - # ) - - # await prisma_client.db.litellm_usertable.upsert( - # where={"user_id": end_user_id}, - # data={ - # "create": user_object_json, - # "update": {"spend": {"increment": response_cost}}, - # }, - # ) - - # elif custom_db_client is not None: - # for id in user_ids: - # if id is None: - # continue - # if ( - # custom_db_client is not None - # and id != litellm_proxy_budget_name - # ): - # existing_spend_obj = await custom_db_client.get_data( - # key=id, table_name="user" - # ) - # verbose_proxy_logger.debug( - # f"Updating existing_spend_obj: {existing_spend_obj}" - # ) - # if existing_spend_obj is None: - # # if user does not exist in LiteLLM_UserTable, create a new user - # existing_spend = 0 - # max_user_budget = None - # if litellm.max_user_budget is not None: - # max_user_budget = litellm.max_user_budget - # existing_spend_obj = LiteLLM_UserTable( - # user_id=id, - # spend=0, - # max_budget=max_user_budget, - # user_email=None, - # ) - # else: - # existing_spend = existing_spend_obj.spend - - # # Calculate the new cost by adding the existing cost and response_cost - # existing_spend_obj.spend = existing_spend + response_cost - - # # track cost per model, for the given user - # spend_per_model = existing_spend_obj.model_spend or {} - # current_model = kwargs.get("model") - - # if current_model is not None and spend_per_model is not None: - # if spend_per_model.get(current_model) is None: - # spend_per_model[current_model] = response_cost - # else: - # spend_per_model[current_model] += response_cost - # existing_spend_obj.model_spend = spend_per_model - - # valid_token = user_api_key_cache.get_cache(key=id) - # if valid_token is not None and isinstance(valid_token, dict): - # user_api_key_cache.set_cache( - # key=id, value=existing_spend_obj.json() - # ) - - # verbose_proxy_logger.debug( - # f"user - new cost: {existing_spend_obj.spend}, user_id: {id}" - # ) - # data_list.append(existing_spend_obj) - - # if custom_db_client is not None and user_id is not None: - # new_spend = data_list[0].spend - # await custom_db_client.update_data( - # key=user_id, - # value={"spend": new_spend}, - # table_name="user", - # ) + prisma_client.user_list_transactons[_id] = ( + response_cost + + prisma_client.user_list_transactons.get(_id, 0) + ) + if end_user_id is not None: + prisma_client.end_user_list_transactons[end_user_id] = ( + response_cost + + prisma_client.user_list_transactons.get(end_user_id, 0) + ) except Exception as e: verbose_proxy_logger.info( "\033[91m" @@ -1220,38 +1121,13 @@ async def update_database( async def _update_key_db(): try: verbose_proxy_logger.debug( - f"adding spend to key db. Response cost: {response_cost}. Token: {token}." + f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}." ) if prisma_client is not None: - await prisma_client.db.litellm_verificationtoken.update( - where={"token": token}, - data={"spend": {"increment": response_cost}}, + prisma_client.key_list_transactons[hashed_token] = ( + response_cost + + prisma_client.key_list_transactons.get(hashed_token, 0) ) - elif custom_db_client is not None: - # Fetch the existing cost for the given token - existing_spend_obj = await custom_db_client.get_data( - key=token, table_name="key" - ) - verbose_proxy_logger.debug( - f"_update_key_db existing spend: {existing_spend_obj}" - ) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost - - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await custom_db_client.update_data( - key=token, value={"spend": new_spend}, table_name="key" - ) - - valid_token = user_api_key_cache.get_cache(key=token) - if valid_token is not None: - valid_token.spend = new_spend - user_api_key_cache.set_cache(key=token, value=valid_token) except Exception as e: verbose_proxy_logger.info( f"Update Key DB Call failed to execute - {str(e)}\n{traceback.format_exc()}" @@ -1273,16 +1149,13 @@ async def update_database( payload["spend"] = response_cost if prisma_client is not None: await prisma_client.insert_data(data=payload, table_name="spend") - elif custom_db_client is not None: - await custom_db_client.insert_data(payload, table_name="spend") - except Exception as e: verbose_proxy_logger.info( f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}" ) raise e - ### UPDATE KEY SPEND ### + ### UPDATE TEAM SPEND ### async def _update_team_db(): try: verbose_proxy_logger.debug( @@ -1294,46 +1167,19 @@ async def update_database( ) return if prisma_client is not None: - await prisma_client.db.litellm_teamtable.update( - where={"team_id": team_id}, - data={"spend": {"increment": response_cost}}, + prisma_client.team_list_transactons[team_id] = ( + response_cost + + prisma_client.team_list_transactons.get(team_id, 0) ) - elif custom_db_client is not None: - # Fetch the existing cost for the given token - existing_spend_obj = await custom_db_client.get_data( - key=token, table_name="key" - ) - verbose_proxy_logger.debug( - f"_update_key_db existing spend: {existing_spend_obj}" - ) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost - - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await custom_db_client.update_data( - key=token, value={"spend": new_spend}, table_name="key" - ) - - valid_token = user_api_key_cache.get_cache(key=token) - if valid_token is not None: - valid_token.spend = new_spend - user_api_key_cache.set_cache(key=token, value=valid_token) except Exception as e: verbose_proxy_logger.info( f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}" ) raise e - # asyncio.create_task(_update_user_db()) - await _update_user_db() - - # asyncio.create_task(_update_key_db()) - # asyncio.create_task(_update_team_db()) + asyncio.create_task(_update_user_db()) + asyncio.create_task(_update_key_db()) + asyncio.create_task(_update_team_db()) # asyncio.create_task(_insert_spend_log_to_db()) verbose_proxy_logger.debug("Runs spend update on all tables") @@ -2237,7 +2083,6 @@ async def generate_key_helper_fn( saved_token["expires"] = saved_token["expires"].isoformat() if prisma_client is not None: ## CREATE USER (If necessary) - verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}") if query_type == "insert_data": user_row = await prisma_client.insert_data( data=user_data, table_name="user" @@ -2575,7 +2420,6 @@ async def startup_event(): # add master key to db if os.getenv("PROXY_ADMIN_ID", None) is not None: litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID") - asyncio.create_task( generate_key_helper_fn( duration=None, @@ -2640,10 +2484,7 @@ async def startup_event(): scheduler.add_job( reset_budget, "interval", seconds=interval, args=[prisma_client] ) - # scheduler.add_job( - # monitor_spend_list, "interval", seconds=10, args=[prisma_client] - # ) - scheduler.add_job(update_spend, "interval", seconds=60, args=[prisma_client]) + scheduler.add_job(update_spend, "interval", seconds=10, args=[prisma_client]) scheduler.start() diff --git a/litellm/proxy/tests/load_test_completion.py b/litellm/proxy/tests/load_test_completion.py index 3f0da2e94..9450c1cb5 100644 --- a/litellm/proxy/tests/load_test_completion.py +++ b/litellm/proxy/tests/load_test_completion.py @@ -7,6 +7,7 @@ from dotenv import load_dotenv litellm_client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234") + async def litellm_completion(): # Your existing code for litellm_completion goes here try: @@ -18,6 +19,7 @@ async def litellm_completion(): "content": f"{text}. Who was alexander the great? {uuid.uuid4()}", } ], + user="my-new-end-user-1", ) return response @@ -29,9 +31,9 @@ async def litellm_completion(): async def main(): - for i in range(6): + for i in range(3): start = time.time() - n = 20 # Number of concurrent tasks + n = 10 # Number of concurrent tasks tasks = [litellm_completion() for _ in range(n)] chat_completions = await asyncio.gather(*tasks) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 17f4f9842..cc41c8ec8 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -7,6 +7,10 @@ from litellm.proxy._types import ( LiteLLM_VerificationToken, LiteLLM_VerificationTokenView, LiteLLM_SpendLogs, + LiteLLM_UserTable, + LiteLLM_EndUserTable, + LiteLLM_TeamTable, + Member, ) from litellm.caching import DualCache from litellm.proxy.hooks.parallel_request_limiter import ( @@ -472,9 +476,10 @@ def on_backoff(details): class PrismaClient: - user_list_transactons: List = [] - key_list_transactons: List = [] - team_list_transactons: List = [] + user_list_transactons: dict = {} + end_user_list_transactons: dict = {} + key_list_transactons: dict = {} + team_list_transactons: dict = {} spend_log_transactons: List = [] def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): @@ -1855,34 +1860,62 @@ async def update_spend( Triggered every minute. Requires: - user_id_list: list, + user_id_list: dict, keys_list: list, team_list: list, spend_logs: list, """ + verbose_proxy_logger.debug( + f"ENTERS UPDATE SPEND - len(prisma_client.user_list_transactons.keys()): {len(prisma_client.user_list_transactons.keys())}" + ) n_retry_times = 3 ### UPDATE USER TABLE ### - if len(prisma_client.user_list_transactons) > 0: + if len(prisma_client.user_list_transactons.keys()) > 0: for i in range(n_retry_times + 1): try: - remaining_transactions = list(prisma_client.user_list_transactons) - while remaining_transactions: - batch_size = min(5000, len(remaining_transactions)) - batch_transactions = remaining_transactions[:batch_size] - async with prisma_client.db.tx(timeout=60000) as transaction: - async with transaction.batch_() as batcher: - for user_id_tuple in batch_transactions: - user_id, response_cost = user_id_tuple - if user_id != "litellm-proxy-budget": - batcher.litellm_usertable.update( - where={"user_id": user_id}, - data={"spend": {"increment": response_cost}}, - ) - - remaining_transactions = remaining_transactions[batch_size:] - + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + user_id, + response_cost, + ) in prisma_client.user_list_transactons.items(): + batcher.litellm_usertable.update_many( + where={"user_id": user_id}, + data={"spend": {"increment": response_cost}}, + ) prisma_client.user_list_transactons = ( - [] + {} + ) # Clear the remaining transactions after processing all batches in the loop. + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + + ### UPDATE END-USER TABLE ### + if len(prisma_client.end_user_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + end_user_id, + response_cost, + ) in prisma_client.end_user_list_transactons.items(): + max_user_budget = None + if litellm.max_user_budget is not None: + max_user_budget = litellm.max_user_budget + new_user_obj = LiteLLM_EndUserTable( + user_id=end_user_id, spend=response_cost, blocked=False + ) + batcher.litellm_endusertable.update_many( + where={"user_id": end_user_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.end_user_list_transactons = ( + {} ) # Clear the remaining transactions after processing all batches in the loop. except httpx.ReadTimeout: if i >= n_retry_times: # If we've reached the maximum number of retries @@ -1893,7 +1926,55 @@ async def update_spend( raise e ### UPDATE KEY TABLE ### + if len(prisma_client.key_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + token, + response_cost, + ) in prisma_client.key_list_transactons.items(): + batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists + where={"token": token}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.key_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + ### UPDATE TEAM TABLE ### + if len(prisma_client.team_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + team_id, + response_cost, + ) in prisma_client.team_list_transactons.items(): + batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists + where={"team_id": team_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.team_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + ### UPDATE SPEND LOGS TABLE ###