diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 2fd6baba2..286ccfeea 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -636,29 +636,39 @@ async def update_database( ### UPDATE USER SPEND ### async def _update_user_db(): - if user_id is None: - return - if prisma_client is not None: - existing_spend_obj = await prisma_client.get_data(user_id=user_id) - elif custom_db_client is not None: - existing_spend_obj = await custom_db_client.get_data( - key=user_id, table_name="user" - ) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend + """ + - Update that user's row + - Update litellm-proxy-budget row (global proxy spend) + """ + user_ids = [user_id, "litellm-proxy-budget"] + data_list = [] + for id in user_ids: + if id is None: + continue + if prisma_client is not None: + existing_spend_obj = await prisma_client.get_data(user_id=id) + elif custom_db_client is not None: + existing_spend_obj = await custom_db_client.get_data( + key=id, table_name="user" + ) + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost + # Calculate the new cost by adding the existing cost and response_cost + existing_spend_obj.spend = existing_spend + response_cost + + verbose_proxy_logger.debug(f"new cost: {existing_spend_obj.spend}") + data_list.append(existing_spend_obj) - verbose_proxy_logger.debug(f"new cost: {new_spend}") # Update the cost column for the given user id if prisma_client is not None: await prisma_client.update_data( - user_id=user_id, data={"spend": new_spend} + data_list=data_list, query_type="update_many", table_name="user" ) - elif custom_db_client is not None: + elif custom_db_client is not None and user_id is not None: + new_spend = data_list[0].spend await custom_db_client.update_data( key=user_id, value={"spend": new_spend}, table_name="user" ) @@ -1563,7 +1573,13 @@ async def startup_event(): if prisma_client is not None and master_key is not None: # add master key to db await generate_key_helper_fn( - duration=None, models=[], aliases={}, config={}, spend=0, token=master_key + duration=None, + models=[], + aliases={}, + config={}, + spend=0, + token=master_key, + user_id="default_user_id", ) if ( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 4e6b88b1c..8d06106c0 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -634,7 +634,7 @@ class PrismaClient: "update": {}, # don't do anything if it already exists }, ) - print_verbose( + verbose_proxy_logger.info( "\033[91m" + f"DB User Table - update succeeded {update_user_row}" + "\033[0m" @@ -678,6 +678,7 @@ class PrismaClient: Batch write update queries """ batcher = self.db.batch_() + verbose_proxy_logger.debug(f"data list for user table: {data_list}") for idx, user in enumerate(data_list): try: data_json = self.jsonify_object(data=user.model_dump()) @@ -688,8 +689,8 @@ class PrismaClient: data={**data_json}, # type: ignore ) await batcher.commit() - print_verbose( - "\033[91m" + f"DB User Table update succeeded" + "\033[0m" + verbose_proxy_logger.info( + "\033[91m" + f"DB User Table Batch update succeeded" + "\033[0m" ) except Exception as e: asyncio.create_task( diff --git a/litellm/utils.py b/litellm/utils.py index 03d38ff35..b41a554d7 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1090,7 +1090,12 @@ class Logging: else: # streaming chunks + image gen. self.model_call_details["response_cost"] = None - if litellm.max_budget and self.stream: + if ( + litellm.max_budget + and self.stream + and result is not None + and "content" in result + ): time_diff = (end_time - start_time).total_seconds() float_diff = float(time_diff) litellm._current_cost += litellm.completion_cost(