diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 4982900948..72c59d478e 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -427,7 +427,7 @@ def run_server( if os.getenv("DATABASE_URL", None) is not None: try: ### add connection pool + pool timeout args - params = {"connection_limit": 500, "pool_timeout": 60} + params = {"connection_limit": 200, "pool_timeout": 10} database_url = os.getenv("DATABASE_URL") modified_url = append_query_params(database_url, params) os.environ["DATABASE_URL"] = modified_url diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index b57f780895..865a06abe5 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -870,146 +870,47 @@ async def update_database( verbose_proxy_logger.info( f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}" ) - - ### [TODO] STEP 1: GET KEY + USER SPEND ### (key, user) - - ### [TODO] STEP 2: UPDATE SPEND ### (key, user, spend logs) + payload = get_logging_payload( + kwargs=kwargs, + response_obj=completion_response, + start_time=start_time, + end_time=end_time, + ) + payload["spend"] = response_cost ### UPDATE USER SPEND ### async def _update_user_db(): """ - - Update that user's row - - Update litellm-proxy-budget row (global proxy spend) + Update user row + proxy budget """ user_ids = [user_id, litellm_proxy_budget_name] - data_list = [] - try: - for id in user_ids: - if id is None: - continue - if prisma_client is not None: - existing_spend_obj = await prisma_client.get_data(user_id=id) - elif ( - custom_db_client is not None and id != litellm_proxy_budget_name - ): - existing_spend_obj = await custom_db_client.get_data( - key=id, table_name="user" - ) - verbose_proxy_logger.debug( - f"Updating existing_spend_obj: {existing_spend_obj}" - ) - if existing_spend_obj is None: - # if user does not exist in LiteLLM_UserTable, create a new user - existing_spend = 0 - max_user_budget = None - if litellm.max_user_budget is not None: - max_user_budget = litellm.max_user_budget - existing_spend_obj = LiteLLM_UserTable( - user_id=id, - spend=0, - max_budget=max_user_budget, - user_email=None, - ) - else: - existing_spend = existing_spend_obj.spend - - # Calculate the new cost by adding the existing cost and response_cost - existing_spend_obj.spend = existing_spend + response_cost - - valid_token = user_api_key_cache.get_cache(key=id) - if valid_token is not None and isinstance(valid_token, dict): - user_api_key_cache.set_cache( - key=id, value=existing_spend_obj.json() - ) - - verbose_proxy_logger.debug(f"new cost: {existing_spend_obj.spend}") - data_list.append(existing_spend_obj) - - # Update the cost column for the given user id - if prisma_client is not None: - await prisma_client.update_data( - data_list=data_list, - query_type="update_many", - table_name="user", - ) - elif custom_db_client is not None and user_id is not None: - new_spend = data_list[0].spend - await custom_db_client.update_data( - key=user_id, value={"spend": new_spend}, table_name="user" - ) - except Exception as e: - verbose_proxy_logger.info(f"Update User DB call failed to execute") + user_ids_str = ", ".join( + f"'{id}'" for id in user_ids + ) # Enclose each id in single quotes + sql_query = f""" + UPDATE "LiteLLM_UserTable" + SET spend = spend + {response_cost} + WHERE user_id IN ({user_ids_str}) + """ + await prisma_client.sql_executor(sql_query=sql_query) ### UPDATE KEY SPEND ### async def _update_key_db(): - try: - verbose_proxy_logger.debug( - f"adding spend to key db. Response cost: {response_cost}. Token: {token}." - ) - if prisma_client is not None: - # Fetch the existing cost for the given token - existing_spend_obj = await prisma_client.get_data(token=token) - verbose_proxy_logger.debug( - f"_update_key_db: existing spend: {existing_spend_obj}" - ) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost - - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await prisma_client.update_data( - token=token, data={"spend": new_spend} - ) - - valid_token = user_api_key_cache.get_cache(key=token) - if valid_token is not None: - valid_token.spend = new_spend - user_api_key_cache.set_cache(key=token, value=valid_token) - elif custom_db_client is not None: - # Fetch the existing cost for the given token - existing_spend_obj = await custom_db_client.get_data( - key=token, table_name="key" - ) - verbose_proxy_logger.debug( - f"_update_key_db existing spend: {existing_spend_obj}" - ) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost - - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await custom_db_client.update_data( - key=token, value={"spend": new_spend}, table_name="key" - ) - - valid_token = user_api_key_cache.get_cache(key=token) - if valid_token is not None: - valid_token.spend = new_spend - user_api_key_cache.set_cache(key=token, value=valid_token) - except Exception as e: - verbose_proxy_logger.info(f"Update Key DB Call failed to execute") + """ + Update key row + """ + sql_query = f""" + UPDATE "LiteLLM_VerificationToken" + SET spend = spend + {response_cost} + WHERE token = '{token}' + """ + await prisma_client.sql_executor(sql_query=sql_query) ### UPDATE SPEND LOGS ### async def _insert_spend_log_to_db(): try: # Helper to generate payload to log verbose_proxy_logger.debug("inserting spend log to db") - payload = get_logging_payload( - kwargs=kwargs, - response_obj=completion_response, - start_time=start_time, - end_time=end_time, - ) - - payload["spend"] = response_cost if prisma_client is not None: await prisma_client.insert_data(data=payload, table_name="spend") diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 8804028cc1..1049a9f430 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -911,6 +911,26 @@ class PrismaClient: response = await self.db.query_raw(sql_query) return response + @backoff.on_exception( + backoff.expo, + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for + on_backoff=on_backoff, # specifying the function to call on backoff + ) + async def sql_executor(self, sql_query): + """ + Executes sql queries against the prisma client + """ + + # Execute the raw query + # The asterisk before `user_id_list` unpacks the list into separate arguments + try: + response = await self.db.query_raw(sql_query) + return response + except Exception as e: + raise e + class DBClient: """