From fd9c7a90af48bd9db3db8b10a267e4671d3bf96c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 6 Feb 2024 23:06:05 -0800 Subject: [PATCH] fix(proxy_server.py): update user cache to with new spend --- litellm/_logging.py | 2 +- litellm/proxy/proxy_server.py | 29 ++++++++++++++++++++++------- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/litellm/_logging.py b/litellm/_logging.py index 171761c3c2..438fa9743d 100644 --- a/litellm/_logging.py +++ b/litellm/_logging.py @@ -3,7 +3,7 @@ import logging set_verbose = False # Create a handler for the logger (you may need to adapt this based on your needs) -handler = logging.FileHandler("log_file.txt") +handler = logging.StreamHandler() handler.setLevel(logging.DEBUG) # Create a formatter and set it for the handler diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 5eb20e22ef..6b89b0f1b5 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -432,10 +432,18 @@ async def user_api_key_auth( # Check 2. If user_id for this token is in budget ## Check 2.5 If global proxy is in budget if valid_token.user_id is not None: - user_id_information = user_api_key_cache.get_cache( - key=valid_token.user_id - ) - if user_id_information is None: + user_id_list = [valid_token.user_id, litellm_proxy_budget_name] + user_id_information = None + for id in user_id_list: + value = user_api_key_cache.get_cache(key=id) + if value is not None: + if user_id_information is None: + user_id_information = [] + user_id_information.append(value) + if user_id_information is None or ( + isinstance(user_id_information, list) + and len(user_id_information) < 2 + ): if prisma_client is not None: user_id_information = await prisma_client.get_data( user_id_list=[ @@ -445,13 +453,14 @@ async def user_api_key_auth( table_name="user", query_type="find_all", ) + for _id in user_id_information: + user_api_key_cache.set_cache( + key=_id["user_id"], value=_id, ttl=600 + ) if custom_db_client is not None: user_id_information = await custom_db_client.get_data( key=valid_token.user_id, table_name="user" ) - user_api_key_cache.set_cache( - key=valid_token.user_id, value=user_id_information, ttl=600 - ) verbose_proxy_logger.debug( f"user_id_information: {user_id_information}" @@ -879,6 +888,12 @@ async def update_database( # Calculate the new cost by adding the existing cost and response_cost existing_spend_obj.spend = existing_spend + response_cost + valid_token = user_api_key_cache.get_cache(key=id) + if valid_token is not None and isinstance(valid_token, dict): + user_api_key_cache.set_cache( + key=id, value=existing_spend_obj.json() + ) + verbose_proxy_logger.debug(f"new cost: {existing_spend_obj.spend}") data_list.append(existing_spend_obj)