From ad43138f282a2fae9ef8ad7f2c27850fdfbd024c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 27 Mar 2024 19:43:15 -0700 Subject: [PATCH] fix(proxy_server.py): fix budget add logic to accurately log who created it --- litellm/proxy/proxy_server.py | 62 +++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 69f75115a..04f35a02b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1008,7 +1008,9 @@ async def user_api_key_auth( # Do something if the current route starts with any of the allowed routes pass else: - if _is_user_proxy_admin(user_id_information): + if user_id_information is not None and _is_user_proxy_admin( + user_id_information + ): return UserAPIKeyAuth( api_key=api_key, user_role="proxy_admin", **valid_token_dict ) @@ -2262,9 +2264,7 @@ async def generate_key_helper_fn( spend: float, key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key key_budget_duration: Optional[str] = None, - key_soft_budget: Optional[ - float - ] = None, # key_soft_budget is used to Budget Per key + budget_id: Optional[float] = None, # budget id <-> LiteLLM_BudgetTable soft_budget: Optional[ float ] = None, # soft_budget is used to set soft Budgets Per user @@ -2287,7 +2287,7 @@ async def generate_key_helper_fn( model_max_budget: Optional[dict] = {}, table_name: Optional[Literal["key", "user"]] = None, ): - global prisma_client, custom_db_client, user_api_key_cache + global prisma_client, custom_db_client, user_api_key_cache, litellm_proxy_admin_name if prisma_client is None and custom_db_client is None: raise Exception( @@ -2325,25 +2325,6 @@ async def generate_key_helper_fn( rpm_limit = rpm_limit allowed_cache_controls = allowed_cache_controls - # TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable - _budget_id = None - if prisma_client is not None and key_soft_budget is not None: - # create the Budget Row for the LiteLLM Verification Token - budget_row = LiteLLM_BudgetTable( - soft_budget=key_soft_budget, - model_max_budget=model_max_budget or {}, - ) - new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True)) - - _budget = await prisma_client.db.litellm_budgettable.create( - data={ - **new_budget, # type: ignore - "created_by": user_id, - "updated_by": user_id, - } - ) - _budget_id = getattr(_budget, "budget_id", None) - try: # Create a new verification token (you may want to enhance this logic based on your needs) user_data = { @@ -2381,7 +2362,7 @@ async def generate_key_helper_fn( "allowed_cache_controls": allowed_cache_controls, "permissions": permissions_json, "model_max_budget": model_max_budget_json, - "budget_id": _budget_id, + "budget_id": budget_id, } if ( general_settings.get("allow_user_auth", False) == True @@ -2461,7 +2442,7 @@ async def generate_key_helper_fn( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) # Add budget related info in key_data - this ensures it's returned - key_data["soft_budget"] = key_soft_budget + key_data["budget_id"] = budget_id return key_data @@ -3941,6 +3922,7 @@ async def moderations( ) async def generate_key_fn( data: GenerateKeyRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), Authorization: Optional[str] = Header(None), ): """ @@ -4044,18 +4026,42 @@ async def generate_key_fn( data, key, litellm.upperbound_key_generate_params[key] ) + # TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable + _budget_id = None + if prisma_client is not None and data.soft_budget is not None: + # create the Budget Row for the LiteLLM Verification Token + budget_row = LiteLLM_BudgetTable( + soft_budget=data.soft_budget, + model_max_budget=data.model_max_budget or {}, + ) + new_budget = prisma_client.jsonify_object( + budget_row.json(exclude_none=True) + ) + + _budget = await prisma_client.db.litellm_budgettable.create( + data={ + **new_budget, # type: ignore + "created_by": user_api_key_dict.user_id, + "updated_by": user_api_key_dict.user_id, + } + ) + _budget_id = getattr(_budget, "budget_id", None) data_json = data.json() # type: ignore # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users if "max_budget" in data_json: data_json["key_max_budget"] = data_json.pop("max_budget", None) - if "soft_budget" in data_json: - data_json["key_soft_budget"] = data_json.pop("soft_budget", None) + if _budget_id is not None: + data_json["budget_id"] = _budget_id if "budget_duration" in data_json: data_json["key_budget_duration"] = data_json.pop("budget_duration", None) response = await generate_key_helper_fn(**data_json, table_name="key") + + response["soft_budget"] = ( + data.soft_budget + ) # include the user-input soft budget in the response return GenerateKeyResponse(**response) except Exception as e: traceback.print_exc()