From cf090acb2508c03f121a1bf69afb6a2ecb6cc1b7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 13 Mar 2024 16:13:37 -0700 Subject: [PATCH 1/3] fix(proxy_server.py): move to using UPDATE + SET for track_cost_callback --- litellm/proxy/proxy_server.py | 375 ++++++++++++++++++---------------- 1 file changed, 197 insertions(+), 178 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index b30fced3d..e374686e0 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -387,6 +387,7 @@ async def user_api_key_auth( user_api_key_cache.set_cache( key=hash_token(master_key), value=_user_api_key_obj ) + return _user_api_key_obj if isinstance( @@ -1007,6 +1008,8 @@ async def _PROXY_track_cost_callback( start_time=start_time, end_time=end_time, ) + + await update_cache(token=user_api_key, response_cost=response_cost) else: raise Exception("User API key missing from custom callback.") else: @@ -1049,10 +1052,6 @@ async def update_database( f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}" ) - ### [TODO] STEP 1: GET KEY + USER SPEND ### (key, user) - - ### [TODO] STEP 2: UPDATE SPEND ### (key, user, spend logs) - ### UPDATE USER SPEND ### async def _update_user_db(): """ @@ -1062,72 +1061,73 @@ async def update_database( user_ids = [user_id, litellm_proxy_budget_name] data_list = [] try: - for id in user_ids: - if id is None: - continue - if prisma_client is not None: - existing_spend_obj = await prisma_client.get_data(user_id=id) - elif ( - custom_db_client is not None and id != litellm_proxy_budget_name - ): - existing_spend_obj = await custom_db_client.get_data( - key=id, table_name="user" - ) - verbose_proxy_logger.debug( - f"Updating existing_spend_obj: {existing_spend_obj}" + if prisma_client is not None: # update + user_ids = [user_id, litellm_proxy_budget_name] + await prisma_client.db.litellm_usertable.update( + where={"user_id": {"in": user_ids}}, + data={"spend": {"increment": response_cost}}, ) - if existing_spend_obj is None: - # if user does not exist in LiteLLM_UserTable, create a new user - existing_spend = 0 - max_user_budget = None - if litellm.max_user_budget is not None: - max_user_budget = litellm.max_user_budget - existing_spend_obj = LiteLLM_UserTable( - user_id=id, - spend=0, - max_budget=max_user_budget, - user_email=None, + elif custom_db_client is not None: + for id in user_ids: + if id is None: + continue + if ( + custom_db_client is not None + and id != litellm_proxy_budget_name + ): + existing_spend_obj = await custom_db_client.get_data( + key=id, table_name="user" + ) + verbose_proxy_logger.debug( + f"Updating existing_spend_obj: {existing_spend_obj}" ) - else: - existing_spend = existing_spend_obj.spend - - # Calculate the new cost by adding the existing cost and response_cost - existing_spend_obj.spend = existing_spend + response_cost - - # track cost per model, for the given user - spend_per_model = existing_spend_obj.model_spend or {} - current_model = kwargs.get("model") - - if current_model is not None and spend_per_model is not None: - if spend_per_model.get(current_model) is None: - spend_per_model[current_model] = response_cost + if existing_spend_obj is None: + # if user does not exist in LiteLLM_UserTable, create a new user + existing_spend = 0 + max_user_budget = None + if litellm.max_user_budget is not None: + max_user_budget = litellm.max_user_budget + existing_spend_obj = LiteLLM_UserTable( + user_id=id, + spend=0, + max_budget=max_user_budget, + user_email=None, + ) else: - spend_per_model[current_model] += response_cost - existing_spend_obj.model_spend = spend_per_model + existing_spend = existing_spend_obj.spend - valid_token = user_api_key_cache.get_cache(key=id) - if valid_token is not None and isinstance(valid_token, dict): - user_api_key_cache.set_cache( - key=id, value=existing_spend_obj.json() + # Calculate the new cost by adding the existing cost and response_cost + existing_spend_obj.spend = existing_spend + response_cost + + # track cost per model, for the given user + spend_per_model = existing_spend_obj.model_spend or {} + current_model = kwargs.get("model") + + if current_model is not None and spend_per_model is not None: + if spend_per_model.get(current_model) is None: + spend_per_model[current_model] = response_cost + else: + spend_per_model[current_model] += response_cost + existing_spend_obj.model_spend = spend_per_model + + valid_token = user_api_key_cache.get_cache(key=id) + if valid_token is not None and isinstance(valid_token, dict): + user_api_key_cache.set_cache( + key=id, value=existing_spend_obj.json() + ) + + verbose_proxy_logger.debug( + f"user - new cost: {existing_spend_obj.spend}, user_id: {id}" ) + data_list.append(existing_spend_obj) - verbose_proxy_logger.debug( - f"user - new cost: {existing_spend_obj.spend}, user_id: {id}" - ) - data_list.append(existing_spend_obj) - - if custom_db_client is not None and user_id is not None: - new_spend = data_list[0].spend - await custom_db_client.update_data( - key=user_id, value={"spend": new_spend}, table_name="user" - ) - # Update the cost column for the given user id - if prisma_client is not None: - await prisma_client.update_data( - data_list=data_list, - query_type="update_many", - table_name="user", - ) + if custom_db_client is not None and user_id is not None: + new_spend = data_list[0].spend + await custom_db_client.update_data( + key=user_id, + value={"spend": new_spend}, + table_name="user", + ) except Exception as e: verbose_proxy_logger.info( f"Update User DB call failed to execute {str(e)}" @@ -1140,82 +1140,10 @@ async def update_database( f"adding spend to key db. Response cost: {response_cost}. Token: {token}." ) if prisma_client is not None: - # Fetch the existing cost for the given token - existing_spend_obj = await prisma_client.get_data(token=token) - verbose_proxy_logger.debug( - f"_update_key_db: existing spend: {existing_spend_obj}" + await prisma_client.db.litellm_verificationtoken.update( + where={"token": token}, + data={"spend": {"increment": response_cost}}, ) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost - - ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT - soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown - if ( - existing_spend_obj.soft_budget_cooldown == False - and existing_spend_obj.litellm_budget_table is not None - and ( - _is_projected_spend_over_limit( - current_spend=new_spend, - soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget, - ) - == True - ) - ): - key_alias = existing_spend_obj.key_alias - projected_spend, projected_exceeded_date = ( - _get_projected_spend_over_limit( - current_spend=new_spend, - soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget, - ) - ) - soft_limit = existing_spend_obj.litellm_budget_table.soft_budget - user_info = { - "key_alias": key_alias, - "projected_spend": projected_spend, - "projected_exceeded_date": projected_exceeded_date, - } - # alert user - asyncio.create_task( - proxy_logging_obj.budget_alerts( - type="projected_limit_exceeded", - user_info=user_info, - user_max_budget=soft_limit, - user_current_spend=new_spend, - ) - ) - # set cooldown on alert - soft_budget_cooldown = True - # track cost per model, for the given key - spend_per_model = existing_spend_obj.model_spend or {} - current_model = kwargs.get("model") - if current_model is not None and spend_per_model is not None: - if spend_per_model.get(current_model) is None: - spend_per_model[current_model] = response_cost - else: - spend_per_model[current_model] += response_cost - - verbose_proxy_logger.debug( - f"new cost: {new_spend}, new spend per model: {spend_per_model}" - ) - # Update the cost column for the given token - await prisma_client.update_data( - token=token, - data={ - "spend": new_spend, - "model_spend": spend_per_model, - "soft_budget_cooldown": soft_budget_cooldown, - }, - ) - - valid_token = user_api_key_cache.get_cache(key=token) - if valid_token is not None: - valid_token.spend = new_spend - valid_token.model_spend = spend_per_model - user_api_key_cache.set_cache(key=token, value=valid_token) elif custom_db_client is not None: # Fetch the existing cost for the given token existing_spend_obj = await custom_db_client.get_data( @@ -1246,6 +1174,7 @@ async def update_database( verbose_proxy_logger.info( f"Update Key DB Call failed to execute - {str(e)}" ) + raise e ### UPDATE SPEND LOGS ### async def _insert_spend_log_to_db(): @@ -1269,6 +1198,7 @@ async def update_database( verbose_proxy_logger.info( f"Update Spend Logs DB failed to execute - {str(e)}" ) + raise e ### UPDATE KEY SPEND ### async def _update_team_db(): @@ -1282,41 +1212,10 @@ async def update_database( ) return if prisma_client is not None: - # Fetch the existing cost for the given token - existing_spend_obj = await prisma_client.get_data( - team_id=team_id, table_name="team" + await prisma_client.db.litellm_teamtable.update( + where={"team_id": team_id}, + data={"spend": {"increment": response_cost}}, ) - verbose_proxy_logger.debug( - f"_update_team_db: existing spend: {existing_spend_obj}" - ) - if existing_spend_obj is None: - # the team does not exist in the db - return - verbose_proxy_logger.debug( - "team_id does not exist in db, not tracking spend for team" - ) - return - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost - spend_per_model = getattr(existing_spend_obj, "model_spend", {}) - # track cost per model, for the given team - spend_per_model = existing_spend_obj.model_spend or {} - current_model = kwargs.get("model") - if current_model is not None and spend_per_model is not None: - if spend_per_model.get(current_model) is None: - spend_per_model[current_model] = response_cost - else: - spend_per_model[current_model] += response_cost - - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await prisma_client.update_data( - team_id=team_id, - data={"spend": new_spend, "model_spend": spend_per_model}, - table_name="team", - ) - elif custom_db_client is not None: # Fetch the existing cost for the given token existing_spend_obj = await custom_db_client.get_data( @@ -1346,17 +1245,88 @@ async def update_database( verbose_proxy_logger.info( f"Update Team DB failed to execute - {str(e)}" ) + raise e - asyncio.create_task(_update_user_db()) - asyncio.create_task(_update_key_db()) - asyncio.create_task(_update_team_db()) - asyncio.create_task(_insert_spend_log_to_db()) + tasks = [] + tasks.append(_update_user_db()) + tasks.append(_update_key_db()) + tasks.append(_update_team_db()) + tasks.append(_insert_spend_log_to_db()) + + await asyncio.gather(*tasks) verbose_proxy_logger.info("Successfully updated spend in all 3 tables") except Exception as e: verbose_proxy_logger.debug( f"Error updating Prisma database: {traceback.format_exc()}" ) - pass + + +async def update_cache( + token, + response_cost, +): + """ + Use this to update the cache with new user spend. + + Put any alerting logic in here. + """ + ### UPDATE KEY SPEND ### + # Fetch the existing cost for the given token + existing_spend_obj = await user_api_key_cache.async_get_cache(key=token) + verbose_proxy_logger.debug(f"_update_key_db: existing spend: {existing_spend_obj}") + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT + soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown + if ( + existing_spend_obj.soft_budget_cooldown == False + and existing_spend_obj.litellm_budget_table is not None + and ( + _is_projected_spend_over_limit( + current_spend=new_spend, + soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget, + ) + == True + ) + ): + key_alias = existing_spend_obj.key_alias + projected_spend, projected_exceeded_date = _get_projected_spend_over_limit( + current_spend=new_spend, + soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget, + ) + soft_limit = existing_spend_obj.litellm_budget_table.soft_budget + user_info = { + "key_alias": key_alias, + "projected_spend": projected_spend, + "projected_exceeded_date": projected_exceeded_date, + } + # alert user + asyncio.create_task( + proxy_logging_obj.budget_alerts( + type="projected_limit_exceeded", + user_info=user_info, + user_max_budget=soft_limit, + user_current_spend=new_spend, + ) + ) + # set cooldown on alert + soft_budget_cooldown = True + + if existing_spend_obj is None: + existing_team_spend = 0 + else: + existing_team_spend = existing_spend_obj.team_spend + # Calculate the new cost by adding the existing cost and response_cost + existing_spend_obj.team_spend = existing_team_spend + response_cost + + # Update the cost column for the given token + existing_spend_obj.spend = new_spend + user_api_key_cache.set_cache(key=token, value=existing_spend_obj) def run_ollama_serve(): @@ -7238,6 +7208,55 @@ async def get_routes(): return {"routes": routes} +## TEST ENDPOINT +# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)]) +# async def update_database_endpoint( +# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +# ): +# """ +# Test endpoint. DO NOT MERGE IN PROD. + +# Used for isolating and testing our prisma db update logic in high-traffic. +# """ +# try: +# request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}" +# resp = litellm.ModelResponse( +# id=request_id, +# choices=[ +# litellm.Choices( +# finish_reason=None, +# index=0, +# message=litellm.Message( +# content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", +# role="assistant", +# ), +# ) +# ], +# model="gpt-35-turbo", # azure always has model written like this +# usage=litellm.Usage( +# prompt_tokens=210, completion_tokens=200, total_tokens=410 +# ), +# ) +# await _PROXY_track_cost_callback( +# kwargs={ +# "model": "chatgpt-v-2", +# "stream": False, +# "litellm_params": { +# "metadata": { +# "user_api_key": user_api_key_dict.token, +# "user_api_key_user_id": user_api_key_dict.user_id, +# } +# }, +# "response_cost": 0.00002, +# }, +# completion_response=resp, +# start_time=datetime.now(), +# end_time=datetime.now(), +# ) +# except Exception as e: +# raise e + + def _has_user_setup_sso(): """ Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables. From acc672a78fc135f5fe239f7c78ef2546bfec172f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 13 Mar 2024 17:04:51 -0700 Subject: [PATCH 2/3] fix(proxy_server.py): maintain support for model specific budgets --- litellm/proxy/proxy_server.py | 40 +++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e374686e0..4f291df64 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -658,23 +658,37 @@ async def user_api_key_auth( # Check 5. Token Model Spend is under Model budget max_budget_per_model = valid_token.model_max_budget - spend_per_model = valid_token.model_spend - if max_budget_per_model is not None and spend_per_model is not None: + if ( + max_budget_per_model is not None + and isinstance(max_budget_per_model, dict) + and len(max_budget_per_model) > 0 + ): current_model = request_data.get("model") - if current_model is not None: - current_model_spend = spend_per_model.get(current_model, None) - current_model_budget = max_budget_per_model.get(current_model, None) - + ## GET THE SPEND FOR THIS MODEL + twenty_eight_days_ago = datetime.now() - timedelta(days=28) + model_spend = await prisma_client.db.litellm_spendlogs.group_by( + by=["model"], + sum={"spend": True}, + where={ + "AND": [ + {"api_key": valid_token.token}, + {"startTime": {"gt": twenty_eight_days_ago}}, + {"model": current_model}, + ] + }, + ) + if len(model_spend) > 0: if ( - current_model_spend is not None - and current_model_budget is not None + model_spend[0]["model"] == model + and model_spend[0]["_sum"]["spend"] + >= max_budget_per_model["model"] ): - if current_model_spend > current_model_budget: - raise Exception( - f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}" - ) - + current_model_spend = model_spend[0]["_sum"]["spend"] + current_model_budget = max_budget_per_model["model"] + raise Exception( + f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}" + ) # Check 6. Token spend is under Team budget if ( valid_token.spend is not None From 1b807fa3f5321aebc69960c0f50eb5ee4051b811 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 13 Mar 2024 19:10:24 -0700 Subject: [PATCH 3/3] fix(proxy_server.py): fix key caching logic --- litellm/caching.py | 33 +++ litellm/proxy/_types.py | 2 + litellm/proxy/proxy_server.py | 249 +++++++++++++++------- litellm/proxy/utils.py | 1 - litellm/tests/test_key_generate_prisma.py | 4 +- 5 files changed, 214 insertions(+), 75 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 833da1238..f22606bd3 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -742,6 +742,39 @@ class DualCache(BaseCache): except Exception as e: traceback.print_exc() + async def async_get_cache(self, key, local_only: bool = False, **kwargs): + # Try to fetch from in-memory cache first + try: + print_verbose( + f"async get cache: cache key: {key}; local_only: {local_only}" + ) + result = None + if self.in_memory_cache is not None: + in_memory_result = await self.in_memory_cache.async_get_cache( + key, **kwargs + ) + + print_verbose(f"in_memory_result: {in_memory_result}") + if in_memory_result is not None: + result = in_memory_result + + if result is None and self.redis_cache is not None and local_only == False: + # If not found in in-memory cache, try fetching from Redis + redis_result = await self.redis_cache.async_get_cache(key, **kwargs) + + if redis_result is not None: + # Update in-memory cache with the value from Redis + await self.in_memory_cache.async_set_cache( + key, redis_result, **kwargs + ) + + result = redis_result + + print_verbose(f"get cache: cache result: {result}") + return result + except Exception as e: + traceback.print_exc() + def flush_cache(self): if self.in_memory_cache is not None: self.in_memory_cache.flush_cache() diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index fd85280dd..8a7efa1a1 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -535,6 +535,8 @@ class LiteLLM_VerificationToken(LiteLLMBase): permissions: Dict = {} model_spend: Dict = {} model_max_budget: Dict = {} + soft_budget_cooldown: bool = False + litellm_budget_table: Optional[dict] = None # hidden params used for parallel request limiting, not required to create a token user_id_rate_limits: Optional[dict] = None diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 4f291df64..4d2cae032 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -651,7 +651,7 @@ async def user_api_key_auth( ) ) - if valid_token.spend > valid_token.max_budget: + if valid_token.spend >= valid_token.max_budget: raise Exception( f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Token: {valid_token.max_budget}" ) @@ -678,14 +678,17 @@ async def user_api_key_auth( ] }, ) - if len(model_spend) > 0: + if ( + len(model_spend) > 0 + and max_budget_per_model.get(current_model, None) is not None + ): if ( - model_spend[0]["model"] == model + model_spend[0]["model"] == current_model and model_spend[0]["_sum"]["spend"] - >= max_budget_per_model["model"] + >= max_budget_per_model[current_model] ): current_model_spend = model_spend[0]["_sum"]["spend"] - current_model_budget = max_budget_per_model["model"] + current_model_budget = max_budget_per_model[current_model] raise Exception( f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}" ) @@ -742,15 +745,7 @@ async def user_api_key_auth( This makes the user row data accessible to pre-api call hooks. """ - if prisma_client is not None: - asyncio.create_task( - _cache_user_row( - user_id=valid_token.user_id, - cache=user_api_key_cache, - db=prisma_client, - ) - ) - elif custom_db_client is not None: + if custom_db_client is not None: asyncio.create_task( _cache_user_row( user_id=valid_token.user_id, @@ -1023,7 +1018,9 @@ async def _PROXY_track_cost_callback( end_time=end_time, ) - await update_cache(token=user_api_key, response_cost=response_cost) + await update_cache( + token=user_api_key, user_id=user_id, response_cost=response_cost + ) else: raise Exception("User API key missing from custom callback.") else: @@ -1072,15 +1069,54 @@ async def update_database( - Update that user's row - Update litellm-proxy-budget row (global proxy spend) """ - user_ids = [user_id, litellm_proxy_budget_name] + ## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db + end_user_id = None + if isinstance(token, str) and token.startswith("sk-"): + hashed_token = hash_token(token=token) + else: + hashed_token = token + existing_token_obj = await user_api_key_cache.async_get_cache( + key=hashed_token + ) + existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) + if existing_token_obj.user_id != user_id: # an end-user id was passed in + end_user_id = user_id + user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name] data_list = [] try: if prisma_client is not None: # update user_ids = [user_id, litellm_proxy_budget_name] - await prisma_client.db.litellm_usertable.update( + ## do a group update for the user-id of the key + global proxy budget + await prisma_client.db.litellm_usertable.update_many( where={"user_id": {"in": user_ids}}, data={"spend": {"increment": response_cost}}, ) + if end_user_id is not None: + if existing_user_obj is None: + # if user does not exist in LiteLLM_UserTable, create a new user + existing_spend = 0 + max_user_budget = None + if litellm.max_user_budget is not None: + max_user_budget = litellm.max_user_budget + existing_user_obj = LiteLLM_UserTable( + user_id=end_user_id, + spend=0, + max_budget=max_user_budget, + user_email=None, + ) + else: + existing_user_obj.spend = ( + existing_user_obj.spend + response_cost + ) + + await prisma_client.db.litellm_usertable.upsert( + where={"user_id": end_user_id}, + data={ + "create": {**existing_user_obj.json(exclude_none=True)}, + "update": {"spend": {"increment": response_cost}}, + }, + ) + elif custom_db_client is not None: for id in user_ids: if id is None: @@ -1261,13 +1297,11 @@ async def update_database( ) raise e - tasks = [] - tasks.append(_update_user_db()) - tasks.append(_update_key_db()) - tasks.append(_update_team_db()) - tasks.append(_insert_spend_log_to_db()) + asyncio.create_task(_update_user_db()) + asyncio.create_task(_update_key_db()) + asyncio.create_task(_update_team_db()) + asyncio.create_task(_insert_spend_log_to_db()) - await asyncio.gather(*tasks) verbose_proxy_logger.info("Successfully updated spend in all 3 tables") except Exception as e: verbose_proxy_logger.debug( @@ -1277,6 +1311,7 @@ async def update_database( async def update_cache( token, + user_id, response_cost, ): """ @@ -1284,63 +1319,131 @@ async def update_cache( Put any alerting logic in here. """ - ### UPDATE KEY SPEND ### - # Fetch the existing cost for the given token - existing_spend_obj = await user_api_key_cache.async_get_cache(key=token) - verbose_proxy_logger.debug(f"_update_key_db: existing spend: {existing_spend_obj}") - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost - ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT - soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown - if ( - existing_spend_obj.soft_budget_cooldown == False - and existing_spend_obj.litellm_budget_table is not None - and ( - _is_projected_spend_over_limit( + ### UPDATE KEY SPEND ### + async def _update_key_cache(): + # Fetch the existing cost for the given token + if isinstance(token, str) and token.startswith("sk-"): + hashed_token = hash_token(token=token) + else: + hashed_token = token + existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token) + verbose_proxy_logger.debug( + f"_update_key_db: existing spend: {existing_spend_obj}" + ) + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT + soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown + if ( + existing_spend_obj.soft_budget_cooldown == False + and existing_spend_obj.litellm_budget_table is not None + and ( + _is_projected_spend_over_limit( + current_spend=new_spend, + soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget, + ) + == True + ) + ): + key_alias = existing_spend_obj.key_alias + projected_spend, projected_exceeded_date = _get_projected_spend_over_limit( current_spend=new_spend, soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget, ) - == True - ) - ): - key_alias = existing_spend_obj.key_alias - projected_spend, projected_exceeded_date = _get_projected_spend_over_limit( - current_spend=new_spend, - soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget, - ) - soft_limit = existing_spend_obj.litellm_budget_table.soft_budget - user_info = { - "key_alias": key_alias, - "projected_spend": projected_spend, - "projected_exceeded_date": projected_exceeded_date, - } - # alert user - asyncio.create_task( - proxy_logging_obj.budget_alerts( - type="projected_limit_exceeded", - user_info=user_info, - user_max_budget=soft_limit, - user_current_spend=new_spend, + soft_limit = existing_spend_obj.litellm_budget_table.soft_budget + user_info = { + "key_alias": key_alias, + "projected_spend": projected_spend, + "projected_exceeded_date": projected_exceeded_date, + } + # alert user + asyncio.create_task( + proxy_logging_obj.budget_alerts( + type="projected_limit_exceeded", + user_info=user_info, + user_max_budget=soft_limit, + user_current_spend=new_spend, + ) ) - ) - # set cooldown on alert - soft_budget_cooldown = True + # set cooldown on alert + soft_budget_cooldown = True - if existing_spend_obj is None: - existing_team_spend = 0 - else: - existing_team_spend = existing_spend_obj.team_spend - # Calculate the new cost by adding the existing cost and response_cost - existing_spend_obj.team_spend = existing_team_spend + response_cost + if ( + existing_spend_obj is not None + and getattr(existing_spend_obj, "team_spend", None) is not None + ): + existing_team_spend = existing_spend_obj.team_spend + # Calculate the new cost by adding the existing cost and response_cost + existing_spend_obj.team_spend = existing_team_spend + response_cost - # Update the cost column for the given token - existing_spend_obj.spend = new_spend - user_api_key_cache.set_cache(key=token, value=existing_spend_obj) + # Update the cost column for the given token + existing_spend_obj.spend = new_spend + user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj) + + async def _update_user_cache(): + ## UPDATE CACHE FOR USER ID + GLOBAL PROXY + end_user_id = None + if isinstance(token, str) and token.startswith("sk-"): + hashed_token = hash_token(token=token) + else: + hashed_token = token + existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token) + existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) + if existing_token_obj.user_id != user_id: # an end-user id was passed in + end_user_id = user_id + user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id] + + try: + for _id in user_ids: + # Fetch the existing cost for the given user + existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id) + if existing_spend_obj is None: + # if user does not exist in LiteLLM_UserTable, create a new user + existing_spend = 0 + max_user_budget = None + if litellm.max_user_budget is not None: + max_user_budget = litellm.max_user_budget + existing_spend_obj = LiteLLM_UserTable( + user_id=_id, + spend=0, + max_budget=max_user_budget, + user_email=None, + ) + verbose_proxy_logger.debug( + f"_update_user_db: existing spend: {existing_spend_obj}" + ) + if existing_spend_obj is None: + existing_spend = 0 + else: + if isinstance(existing_spend_obj, dict): + existing_spend = existing_spend_obj["spend"] + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + # Update the cost column for the given user + if isinstance(existing_spend_obj, dict): + existing_spend_obj["spend"] = new_spend + user_api_key_cache.set_cache(key=_id, value=existing_spend_obj) + else: + existing_spend_obj.spend = new_spend + user_api_key_cache.set_cache( + key=_id, value=existing_spend_obj.json() + ) + except Exception as e: + verbose_proxy_logger.debug( + f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}" + ) + + asyncio.create_task(_update_key_cache()) + asyncio.create_task(_update_user_cache()) def run_ollama_serve(): diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 42ae6a378..57381bac1 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1596,7 +1596,6 @@ async def _cache_user_row( Check if a user_id exists in cache, if not retrieve it. """ - print_verbose(f"Prisma: _cache_user_row, user_id: {user_id}") cache_key = f"{user_id}_user_api_key_user_id" response = cache.get_cache(key=cache_key) if response is None: # Cache miss diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 62f6c38a9..151781beb 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -318,7 +318,7 @@ def test_call_with_user_over_budget(prisma_client): def test_call_with_end_user_over_budget(prisma_client): - # Test if a user passed to /chat/completions is tracked & fails whe they cross their budget + # Test if a user passed to /chat/completions is tracked & fails when they cross their budget # we only check this when litellm.max_user_budget is set import random @@ -339,6 +339,8 @@ def test_call_with_end_user_over_budget(prisma_client): request = Request(scope={"type": "http"}) request._url = URL(url="/chat/completions") + result = await user_api_key_auth(request=request, api_key=bearer_token) + async def return_body(): return_string = f'{{"model": "gemini-pro-vision", "user": "{user}"}}' # return string as bytes