import asyncio
import os
import traceback
from datetime import datetime
from typing import Any, Optional, Set, Union

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import LiteLLM_UserTable, SpendLogsPayload
from litellm.proxy.proxy_server import hash_token
from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload
from litellm.proxy.utils import PrismaClient, ProxyUpdateSpend


class DBSpendUpdateWriter:
    """
    Records the spend of a single request against the proxy's spend trackers.

    `update_database` adds the request cost to the per-user, per-end-user,
    per-key, per-team, per-team-member and per-org transaction maps held on
    the proxy's `prisma_client`, and (unless spend logs are disabled) queues a
    spend-log payload on `prisma_client.spend_log_transactions`.
    """

    # Strong references to in-flight background tasks. The event loop only
    # keeps weak references to tasks, so a task whose handle is dropped can be
    # garbage collected before it finishes.
    _background_tasks: Set[asyncio.Task] = set()

    @staticmethod
    def _on_background_task_done(task: asyncio.Task) -> None:
        """Release a finished background task and consume its exception, if any."""
        DBSpendUpdateWriter._background_tasks.discard(task)
        if task.cancelled():
            return
        # Retrieving the exception prevents asyncio's
        # "Task exception was never retrieved" warning at collection time.
        exc = task.exception()
        if exc is not None:
            verbose_proxy_logger.debug(
                f"Spend update background task failed - {str(exc)}"
            )

    @staticmethod
    def _spawn_background_task(coro: Any) -> asyncio.Task:
        """Schedule `coro` as a task and hold a reference to it until it completes."""
        task = asyncio.create_task(coro)
        DBSpendUpdateWriter._background_tasks.add(task)
        task.add_done_callback(DBSpendUpdateWriter._on_background_task_done)
        return task

    @staticmethod
    def _add_spend(transactions: Any, key: str, amount: float) -> None:
        """Add `amount` to the running total stored under `key` in `transactions`."""
        transactions[key] = amount + transactions.get(key, 0)

    @staticmethod
    async def update_database(  # noqa: PLR0915
        # LiteLLM management object fields
        token: Optional[str],
        user_id: Optional[str],
        end_user_id: Optional[str],
        team_id: Optional[str],
        org_id: Optional[str],
        # Completion object fields
        kwargs: Optional[dict],
        completion_response: Optional[Union[litellm.ModelResponse, Any, Exception]],
        start_time: Optional[datetime],
        end_time: Optional[datetime],
        response_cost: Optional[float],
    ):
        """
        Record the spend of one request for every entity it is attributed to.

        User, key, team and org updates are scheduled as background tasks; the
        spend-log insert is awaited inline unless `disable_spend_logs` is set.
        Nothing is recorded when `ProxyUpdateSpend.disable_spend_updates()`
        returns True. Errors are logged and never propagated to the caller.

        Args:
            token: API key of the request. Values starting with "sk-" are
                hashed with `hash_token` before being used as a map key.
            user_id: Internal user the spend is attributed to, if any.
            end_user_id: End user the spend is attributed to, if any.
            team_id: Team the spend is attributed to, if any.
            org_id: Organization the spend is attributed to, if any.
            kwargs: Request kwargs forwarded to `get_logging_payload`.
            completion_response: Response object (or exception) forwarded to
                `get_logging_payload`.
            start_time: Request start time forwarded to `get_logging_payload`.
            end_time: Request end time forwarded to `get_logging_payload`.
            response_cost: Cost of the request. `None` is treated as 0.0.
        """
        # Imported at call time so the current values of these proxy_server
        # module globals are read on every call (presumably this also avoids
        # an import cycle with proxy_server - confirm).
        from litellm.proxy.proxy_server import (
            disable_spend_logs,
            litellm_proxy_budget_name,
            prisma_client,
            user_api_key_cache,
        )

        try:
            if isinstance(token, str) and token.startswith("sk-"):
                hashed_token = hash_token(token=token)
            else:
                hashed_token = token

            # Log the hashed token only: the raw "sk-" key is a credential and
            # must not end up in log output.
            verbose_proxy_logger.debug(
                f"Enters prisma db call, response_cost: {response_cost}, token: {hashed_token}; user_id: {user_id}; team_id: {team_id}"
            )
            if ProxyUpdateSpend.disable_spend_updates() is True:
                return

            # `response_cost` is Optional; normalise once so every tracker
            # below adds a number instead of failing on `None + total`.
            spend: float = response_cost or 0.0

            ### UPDATE USER SPEND ###
            async def _update_user_db():
                """
                - Update that user's row
                - Update litellm-proxy-budget row (global proxy spend)
                - Update the end-user's row, if an end-user is passed in
                """
                # NOTE(review): the looked-up user object is never read below.
                # Kept in case the cache lookup has a side effect this file
                # cannot show - confirm and remove if it does not.
                existing_user_obj = await user_api_key_cache.async_get_cache(
                    key=user_id
                )
                if existing_user_obj is not None and isinstance(
                    existing_user_obj, dict
                ):
                    existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
                try:
                    if prisma_client is not None:  # update
                        user_ids = [user_id]
                        if (
                            litellm.max_budget > 0
                        ):  # track global proxy budget, if user set max budget
                            user_ids.append(litellm_proxy_budget_name)
                        for _id in user_ids:
                            if _id is not None:
                                DBSpendUpdateWriter._add_spend(
                                    prisma_client.user_list_transactons, _id, spend
                                )
                        if end_user_id is not None:
                            DBSpendUpdateWriter._add_spend(
                                prisma_client.end_user_list_transactons,
                                end_user_id,
                                spend,
                            )
                except Exception as e:
                    verbose_proxy_logger.info(
                        "\033[91m"
                        + f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
                    )

            ### UPDATE KEY SPEND ###
            async def _update_key_db():
                try:
                    verbose_proxy_logger.debug(
                        f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}."
                    )
                    if hashed_token is None:
                        return
                    if prisma_client is not None:
                        DBSpendUpdateWriter._add_spend(
                            prisma_client.key_list_transactons, hashed_token, spend
                        )
                except Exception as e:
                    verbose_proxy_logger.exception(
                        f"Update Key DB Call failed to execute - {str(e)}"
                    )
                    raise

            ### UPDATE SPEND LOGS ###
            async def _insert_spend_log_to_db():
                try:
                    if prisma_client:
                        payload = get_logging_payload(
                            kwargs=kwargs,
                            response_obj=completion_response,
                            start_time=start_time,
                            end_time=end_time,
                        )
                        payload["spend"] = spend
                        DBSpendUpdateWriter._set_spend_logs_payload(
                            payload=payload,
                            spend_logs_url=os.getenv("SPEND_LOGS_URL"),
                            prisma_client=prisma_client,
                        )
                except Exception as e:
                    verbose_proxy_logger.debug(
                        f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
                    )
                    raise

            ### UPDATE TEAM SPEND ###
            async def _update_team_db():
                try:
                    verbose_proxy_logger.debug(
                        f"adding spend to team db. Response cost: {response_cost}. team_id: {team_id}."
                    )
                    if team_id is None:
                        verbose_proxy_logger.debug(
                            "track_cost_callback: team_id is None. Not tracking spend for team"
                        )
                        return
                    if prisma_client is not None:
                        DBSpendUpdateWriter._add_spend(
                            prisma_client.team_list_transactons, team_id, spend
                        )

                        try:
                            # Track spend of the team member within this team.
                            # Key format: "team_id::<team_id>::user_id::<user_id>"
                            team_member_key = f"team_id::{team_id}::user_id::{user_id}"
                            DBSpendUpdateWriter._add_spend(
                                prisma_client.team_member_list_transactons,
                                team_member_key,
                                spend,
                            )
                        except Exception as e:
                            # Best-effort: a member-level failure must not undo
                            # the team-level update, but it should be visible.
                            verbose_proxy_logger.debug(
                                f"Update Team Member spend failed to execute - {str(e)}"
                            )
                except Exception as e:
                    verbose_proxy_logger.info(
                        f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}"
                    )
                    raise

            ### UPDATE ORG SPEND ###
            async def _update_org_db():
                try:
                    verbose_proxy_logger.debug(
                        "adding spend to org db. Response cost: {}. org_id: {}.".format(
                            response_cost, org_id
                        )
                    )
                    if org_id is None:
                        verbose_proxy_logger.debug(
                            "track_cost_callback: org_id is None. Not tracking spend for org"
                        )
                        return
                    if prisma_client is not None:
                        DBSpendUpdateWriter._add_spend(
                            prisma_client.org_list_transactons, org_id, spend
                        )
                except Exception as e:
                    verbose_proxy_logger.info(
                        f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}"
                    )
                    raise

            # Fire-and-forget, in the same order as before; references are
            # held by `_spawn_background_task` until each task completes.
            for updater in (
                _update_user_db,
                _update_key_db,
                _update_team_db,
                _update_org_db,
            ):
                DBSpendUpdateWriter._spawn_background_task(updater())

            if disable_spend_logs is False:
                await _insert_spend_log_to_db()
            else:
                verbose_proxy_logger.info(
                    "disable_spend_logs=True. Skipping writing spend logs to db. Other spend updates - Key/User/Team table will still occur."
                )

            verbose_proxy_logger.debug("Runs spend update on all tables")
        except Exception:
            verbose_proxy_logger.debug(
                f"Error updating Prisma database: {traceback.format_exc()}"
            )

    @staticmethod
    def _set_spend_logs_payload(
        payload: Union[dict, SpendLogsPayload],
        prisma_client: PrismaClient,
        spend_logs_url: Optional[str] = None,
    ) -> PrismaClient:
        """
        Queue a spend-log payload on `prisma_client`.

        The payload is appended to `prisma_client.spend_log_transactions` and
        a copy is passed to
        `prisma_client.add_spend_log_transaction_to_daily_user_transaction`.
        When `spend_logs_url` is set, `datetime` values under "startTime" and
        "endTime" are first replaced in place by their ISO-8601 strings
        (presumably so the payload can be serialised for that endpoint -
        confirm).

        Args:
            payload: Spend-log record; mutated in place as described above.
            prisma_client: Client holding the transaction queues. If it is
                None, nothing is queued.
            spend_logs_url: Value of the SPEND_LOGS_URL setting, if any.

        Returns:
            The `prisma_client` that was passed in.
        """
        verbose_proxy_logger.info(
            "Writing spend log to db - request_id: {}, spend: {}".format(
                payload.get("request_id"), payload.get("spend")
            )
        )
        if prisma_client is None:
            # The append below was already guarded against None; guard the
            # daily-transaction call the same way instead of raising.
            return prisma_client

        if spend_logs_url is not None:
            for time_field in ("startTime", "endTime"):
                if isinstance(payload[time_field], datetime):
                    payload[time_field] = payload[time_field].isoformat()

        prisma_client.spend_log_transactions.append(payload)
        prisma_client.add_spend_log_transaction_to_daily_user_transaction(
            payload.copy()
        )
        return prisma_client
-67,7 +66,7 @@ class _ProxyDBLogger(CustomLogger): request_data.get("proxy_server_request") or {} ) request_data["litellm_params"]["metadata"] = existing_metadata - await update_database( + await DBSpendUpdateWriter.update_database( token=user_api_key_dict.api_key, response_cost=0.0, user_id=user_api_key_dict.user_id, @@ -94,7 +93,6 @@ class _ProxyDBLogger(CustomLogger): prisma_client, proxy_logging_obj, update_cache, - update_database, ) verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback") @@ -138,7 +136,7 @@ class _ProxyDBLogger(CustomLogger): end_user_id=end_user_id, ): ## UPDATE DATABASE - await update_database( + await DBSpendUpdateWriter.update_database( token=user_api_key, response_cost=response_cost, user_id=user_id, diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 5d7e92fd73..6a2da7d83b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -897,211 +897,6 @@ def cost_tracking(): litellm.logging_callback_manager.add_litellm_callback(_ProxyDBLogger()) -def _set_spend_logs_payload( - payload: Union[dict, SpendLogsPayload], - prisma_client: PrismaClient, - spend_logs_url: Optional[str] = None, -): - verbose_proxy_logger.info( - "Writing spend log to db - request_id: {}, spend: {}".format( - payload.get("request_id"), payload.get("spend") - ) - ) - if prisma_client is not None and spend_logs_url is not None: - if isinstance(payload["startTime"], datetime): - payload["startTime"] = payload["startTime"].isoformat() - if isinstance(payload["endTime"], datetime): - payload["endTime"] = payload["endTime"].isoformat() - prisma_client.spend_log_transactions.append(payload) - elif prisma_client is not None: - prisma_client.spend_log_transactions.append(payload) - - prisma_client.add_spend_log_transaction_to_daily_user_transaction(payload.copy()) - return prisma_client - - -async def update_database( # noqa: PLR0915 - token, - response_cost, - user_id=None, - end_user_id=None, - team_id=None, - 
kwargs=None, - completion_response=None, - start_time=None, - end_time=None, - org_id=None, -): - try: - global prisma_client - verbose_proxy_logger.debug( - f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}" - ) - if ProxyUpdateSpend.disable_spend_updates() is True: - return - if token is not None and isinstance(token, str) and token.startswith("sk-"): - hashed_token = hash_token(token=token) - else: - hashed_token = token - - ### UPDATE USER SPEND ### - async def _update_user_db(): - """ - - Update that user's row - - Update litellm-proxy-budget row (global proxy spend) - """ - ## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db - existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) - if existing_user_obj is not None and isinstance(existing_user_obj, dict): - existing_user_obj = LiteLLM_UserTable(**existing_user_obj) - try: - if prisma_client is not None: # update - user_ids = [user_id] - if ( - litellm.max_budget > 0 - ): # track global proxy budget, if user set max budget - user_ids.append(litellm_proxy_budget_name) - ### KEY CHANGE ### - for _id in user_ids: - if _id is not None: - prisma_client.user_list_transactons[_id] = ( - response_cost - + prisma_client.user_list_transactons.get(_id, 0) - ) - if end_user_id is not None: - prisma_client.end_user_list_transactons[end_user_id] = ( - response_cost - + prisma_client.end_user_list_transactons.get( - end_user_id, 0 - ) - ) - except Exception as e: - verbose_proxy_logger.info( - "\033[91m" - + f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}" - ) - - ### UPDATE KEY SPEND ### - async def _update_key_db(): - try: - verbose_proxy_logger.debug( - f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}." 
- ) - if hashed_token is None: - return - if prisma_client is not None: - prisma_client.key_list_transactons[hashed_token] = ( - response_cost - + prisma_client.key_list_transactons.get(hashed_token, 0) - ) - except Exception as e: - verbose_proxy_logger.exception( - f"Update Key DB Call failed to execute - {str(e)}" - ) - raise e - - ### UPDATE SPEND LOGS ### - async def _insert_spend_log_to_db(): - try: - global prisma_client - if prisma_client is not None: - # Helper to generate payload to log - payload = get_logging_payload( - kwargs=kwargs, - response_obj=completion_response, - start_time=start_time, - end_time=end_time, - ) - payload["spend"] = response_cost - prisma_client = _set_spend_logs_payload( - payload=payload, - spend_logs_url=os.getenv("SPEND_LOGS_URL"), - prisma_client=prisma_client, - ) - except Exception as e: - verbose_proxy_logger.debug( - f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}" - ) - raise e - - ### UPDATE TEAM SPEND ### - async def _update_team_db(): - try: - verbose_proxy_logger.debug( - f"adding spend to team db. Response cost: {response_cost}. team_id: {team_id}." - ) - if team_id is None: - verbose_proxy_logger.debug( - "track_cost_callback: team_id is None. 
Not tracking spend for team" - ) - return - if prisma_client is not None: - prisma_client.team_list_transactons[team_id] = ( - response_cost - + prisma_client.team_list_transactons.get(team_id, 0) - ) - - try: - # Track spend of the team member within this team - # key is "team_id::::user_id::" - team_member_key = f"team_id::{team_id}::user_id::{user_id}" - prisma_client.team_member_list_transactons[team_member_key] = ( - response_cost - + prisma_client.team_member_list_transactons.get( - team_member_key, 0 - ) - ) - except Exception: - pass - except Exception as e: - verbose_proxy_logger.info( - f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}" - ) - raise e - - ### UPDATE ORG SPEND ### - async def _update_org_db(): - try: - verbose_proxy_logger.debug( - "adding spend to org db. Response cost: {}. org_id: {}.".format( - response_cost, org_id - ) - ) - if org_id is None: - verbose_proxy_logger.debug( - "track_cost_callback: org_id is None. Not tracking spend for org" - ) - return - if prisma_client is not None: - prisma_client.org_list_transactons[org_id] = ( - response_cost - + prisma_client.org_list_transactons.get(org_id, 0) - ) - except Exception as e: - verbose_proxy_logger.info( - f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}" - ) - raise e - - asyncio.create_task(_update_user_db()) - asyncio.create_task(_update_key_db()) - asyncio.create_task(_update_team_db()) - asyncio.create_task(_update_org_db()) - # asyncio.create_task(_insert_spend_log_to_db()) - if disable_spend_logs is False: - await _insert_spend_log_to_db() - else: - verbose_proxy_logger.info( - "disable_spend_logs=True. Skipping writing spend logs to db. Other spend updates - Key/User/Team table will still occur." 
- ) - - verbose_proxy_logger.debug("Runs spend update on all tables") - except Exception: - verbose_proxy_logger.debug( - f"Error updating Prisma database: {traceback.format_exc()}" - ) - - async def update_cache( # noqa: PLR0915 token: Optional[str], user_id: Optional[str], @@ -3294,14 +3089,14 @@ class ProxyStartupEvent: prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj ) - ### GET STORED CREDENTIALS ### - scheduler.add_job( - proxy_config.get_credentials, - "interval", - seconds=10, - args=[prisma_client], - ) - await proxy_config.get_credentials(prisma_client=prisma_client) + ### GET STORED CREDENTIALS ### + scheduler.add_job( + proxy_config.get_credentials, + "interval", + seconds=1, + args=[prisma_client], + ) + await proxy_config.get_credentials(prisma_client=prisma_client) if ( proxy_logging_obj is not None and proxy_logging_obj.slack_alerting_instance.alerting is not None diff --git a/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py b/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py index 1e3b22ae2d..8850436329 100644 --- a/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py +++ b/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py @@ -47,7 +47,8 @@ async def test_async_post_call_failure_hook(): # Mock update_database function with patch( - "litellm.proxy.proxy_server.update_database", new_callable=AsyncMock + "litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter.update_database", + new_callable=AsyncMock, ) as mock_update_database: # Call the method await logger.async_post_call_failure_hook( diff --git a/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py b/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py index 535f5bf019..129be6d754 100644 --- a/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py +++ b/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py @@ -5,6 +5,7 @@ from unittest.mock import Mock, patch, AsyncMock import pytest from fastapi import Request from litellm.proxy.utils 
import _get_redoc_url, _get_docs_url +from datetime import datetime sys.path.insert(0, os.path.abspath("../..")) import litellm @@ -22,16 +23,20 @@ async def test_disable_spend_logs(): with patch("litellm.proxy.proxy_server.disable_spend_logs", True), patch( "litellm.proxy.proxy_server.prisma_client", mock_prisma_client ): - from litellm.proxy.proxy_server import update_database + from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter # Call update_database with disable_spend_logs=True - await update_database( + await DBSpendUpdateWriter.update_database( token="fake-token", response_cost=0.1, user_id="user123", completion_response=None, - start_time="2024-01-01", - end_time="2024-01-01", + start_time=datetime.now(), + end_time=datetime.now(), + end_user_id="end_user_id", + team_id="team_id", + org_id="org_id", + kwargs={}, ) # Verify no spend logs were added assert len(mock_prisma_client.spend_log_transactions) == 0