From d7601a4844b3869e248763f63e2c63582c969b8e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 2 Apr 2024 18:46:17 -0700 Subject: [PATCH] perf(proxy_server.py): batch write spend logs reduces prisma client errors, by batch writing spend logs - max 1k logs at a time --- litellm/proxy/_new_secret_config.yaml | 12 +- litellm/proxy/proxy_server.py | 8 +- litellm/proxy/utils.py | 159 ++++++++++++++++++---- litellm/tests/test_key_generate_prisma.py | 47 +++++-- tests/test_spend_logs.py | 2 +- 5 files changed, 178 insertions(+), 50 deletions(-) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 351112474..db62e5b5f 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -3,18 +3,18 @@ model_list: litellm_params: model: openai/my-fake-model api_key: my-fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ + api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/ - model_name: gpt-instruct litellm_params: model: gpt-3.5-turbo-instruct # api_key: my-fake-key # api_base: https://exampleopenaiendpoint-production.up.railway.app/ -litellm_settings: - drop_params: True - max_budget: 800021 - budget_duration: 30d - # cache: true +# litellm_settings: +# drop_params: True +# max_budget: 800021 +# budget_duration: 30d +# # cache: true general_settings: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e8c0a1a28..24a0f5fab 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1468,8 +1468,8 @@ async def update_database( payload["spend"] = response_cost if ( - os.getenv("SPEND_LOGS_URL", None) is not None - and prisma_client is not None + prisma_client is not None + and os.getenv("SPEND_LOGS_URL", None) is not None ): if isinstance(payload["startTime"], datetime): payload["startTime"] = payload["startTime"].isoformat() @@ -1477,7 +1477,7 @@ async def update_database( payload["endTime"] = 
payload["endTime"].isoformat()
-            prisma_client.spend_log_transactons.append(payload)
+            prisma_client.spend_log_transactions.append(payload)
         elif prisma_client is not None:
-            await prisma_client.insert_data(data=payload, table_name="spend")
+            prisma_client.spend_log_transactions.append(payload)
     except Exception as e:
         verbose_proxy_logger.debug(
             f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
@@ -2966,7 +2966,7 @@ async def startup_event():
             update_spend,
             "interval",
             seconds=batch_writing_interval,
-            args=[prisma_client, db_writer_client],
+            args=[prisma_client, db_writer_client, proxy_logging_obj],
         )
         scheduler.start()

diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 708f77aa8..17f70b322 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -528,7 +528,7 @@ class PrismaClient:
     end_user_list_transactons: dict = {}
     key_list_transactons: dict = {}
     team_list_transactons: dict = {}
-    spend_log_transactons: List = []
+    spend_log_transactions: List = []

     def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
         print_verbose(
@@ -1906,7 +1906,9 @@ async def reset_budget(prisma_client: PrismaClient):


 async def update_spend(
-    prisma_client: PrismaClient, db_writer_client: Optional[HTTPHandler]
+    prisma_client: PrismaClient,
+    db_writer_client: Optional[HTTPHandler],
+    proxy_logging_obj: ProxyLogging,
 ):
     """
     Batch write updates to db.
@@ -1920,7 +1922,6 @@ async def update_spend(
         spend_logs: list,
     """
     n_retry_times = 3
-    verbose_proxy_logger.debug("INSIDE UPDATE SPEND")
     ### UPDATE USER TABLE ###
     if len(prisma_client.user_list_transactons.keys()) > 0:
         for i in range(n_retry_times + 1):
@@ -1940,12 +1941,25 @@ async def update_spend(
                     prisma_client.user_list_transactons = (
                         {}
                     )  # Clear the remaining transactions after processing all batches in the loop.
+ break except httpx.ReadTimeout: if i >= n_retry_times: # If we've reached the maximum number of retries raise # Re-raise the last exception # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # Exponential backoff except Exception as e: + import traceback + + error_msg = ( + f"LiteLLM Prisma Client Exception - update user spend: {str(e)}" + ) + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + asyncio.create_task( + proxy_logging_obj.failure_handler( + original_exception=e, traceback_str=error_traceback + ) + ) raise e ### UPDATE END-USER TABLE ### @@ -1973,12 +1987,25 @@ async def update_spend( prisma_client.end_user_list_transactons = ( {} ) # Clear the remaining transactions after processing all batches in the loop. + break except httpx.ReadTimeout: if i >= n_retry_times: # If we've reached the maximum number of retries raise # Re-raise the last exception # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # Exponential backoff except Exception as e: + import traceback + + error_msg = ( + f"LiteLLM Prisma Client Exception - update end-user spend: {str(e)}" + ) + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + asyncio.create_task( + proxy_logging_obj.failure_handler( + original_exception=e, traceback_str=error_traceback + ) + ) raise e ### UPDATE KEY TABLE ### @@ -2000,12 +2027,25 @@ async def update_spend( prisma_client.key_list_transactons = ( {} ) # Clear the remaining transactions after processing all batches in the loop. 
+ break except httpx.ReadTimeout: if i >= n_retry_times: # If we've reached the maximum number of retries raise # Re-raise the last exception # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # Exponential backoff except Exception as e: + import traceback + + error_msg = ( + f"LiteLLM Prisma Client Exception - update key spend: {str(e)}" + ) + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + asyncio.create_task( + proxy_logging_obj.failure_handler( + original_exception=e, traceback_str=error_traceback + ) + ) raise e ### UPDATE TEAM TABLE ### @@ -2037,39 +2077,108 @@ async def update_spend( prisma_client.team_list_transactons = ( {} ) # Clear the remaining transactions after processing all batches in the loop. + break except httpx.ReadTimeout: if i >= n_retry_times: # If we've reached the maximum number of retries raise # Re-raise the last exception # Optionally, sleep for a bit before retrying await asyncio.sleep(2**i) # Exponential backoff except Exception as e: + import traceback + + error_msg = ( + f"LiteLLM Prisma Client Exception - update team spend: {str(e)}" + ) + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + asyncio.create_task( + proxy_logging_obj.failure_handler( + original_exception=e, traceback_str=error_traceback + ) + ) raise e ### UPDATE SPEND LOGS ### - base_url = os.getenv("SPEND_LOGS_URL", None) - if ( - len(prisma_client.spend_log_transactons) > 0 - and base_url is not None - and db_writer_client is not None - ): - if not base_url.endswith("/"): - base_url += "/" - verbose_proxy_logger.debug("base_url: {}".format(base_url)) - response = await db_writer_client.post( - url=base_url + "spend/update", - data=json.dumps(prisma_client.spend_log_transactons), # type: ignore - headers={"Content-Type": "application/json"}, - ) - if response.status_code == 200: - prisma_client.spend_log_transactons = [] + verbose_proxy_logger.debug( + "Spend Logs 
transactions: {}".format(len(prisma_client.spend_log_transactions))
+    )
+    BATCH_SIZE = 100  # Preferred size of each batch to write to the database
+    MAX_LOGS_PER_INTERVAL = 1000  # Maximum number of logs to flush in a single interval

-# async def monitor_spend_list(prisma_client: PrismaClient):
-#     """
-#     Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db
-#     """
-#     if len(prisma_client.user_list_transactons) > 10000:
-#         await update_spend(prisma_client=prisma_client)
+    if len(prisma_client.spend_log_transactions) > 0:
+        for i in range(n_retry_times + 1):
+            try:
+                base_url = os.getenv("SPEND_LOGS_URL", None)
+                ## WRITE TO SEPARATE SERVER ##
+                if (
+                    len(prisma_client.spend_log_transactions) > 0
+                    and base_url is not None
+                    and db_writer_client is not None
+                ):
+                    if not base_url.endswith("/"):
+                        base_url += "/"
+                    verbose_proxy_logger.debug("base_url: {}".format(base_url))
+                    response = await db_writer_client.post(
+                        url=base_url + "spend/update",
+                        data=json.dumps(prisma_client.spend_log_transactions),  # type: ignore
+                        headers={"Content-Type": "application/json"},
+                    )
+                    if response.status_code == 200:
+                        prisma_client.spend_log_transactions = []
+                else:  ## (default) WRITE TO DB ##
+                    logs_to_process = prisma_client.spend_log_transactions[
+                        :MAX_LOGS_PER_INTERVAL
+                    ]
+                    for batch_start in range(0, len(logs_to_process), BATCH_SIZE):
+                        # Create sublist for current batch, ensuring it doesn't exceed the BATCH_SIZE
+                        batch = logs_to_process[batch_start : batch_start + BATCH_SIZE]
+
+                        # Convert datetime strings to Date objects
+                        batch_with_dates = [
+                            prisma_client.jsonify_object(
+                                {
+                                    **entry,
+                                }
+                            )
+                            for entry in batch
+                        ]
+
+                        await prisma_client.db.litellm_spendlogs.create_many(
+                            data=batch_with_dates, skip_duplicates=True  # type: ignore
+                        )
+
+                        verbose_proxy_logger.debug(
+                            f"Flushed {len(batch)} logs to the DB."
+ ) + # Remove the processed logs from spend_logs + prisma_client.spend_log_transactions = ( + prisma_client.spend_log_transactions[len(logs_to_process) :] + ) + + verbose_proxy_logger.debug( + f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}" + ) + break + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + import traceback + + error_msg = ( + f"LiteLLM Prisma Client Exception - update spend logs: {str(e)}" + ) + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + asyncio.create_task( + proxy_logging_obj.failure_handler( + original_exception=e, traceback_str=error_traceback + ) + ) + raise e async def _read_request_body(request): diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index aa1fbc2fc..357bc2817 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -51,7 +51,7 @@ from litellm.proxy.proxy_server import ( user_info, info_key_fn, ) -from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token +from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend from litellm._logging import verbose_proxy_logger verbose_proxy_logger.setLevel(level=logging.DEBUG) @@ -1141,9 +1141,9 @@ def test_call_with_key_over_budget(prisma_client): from litellm.caching import Cache litellm.cache = Cache() - import time + import time, uuid - request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}" + request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}" resp = ModelResponse( id=request_id, @@ -1176,7 +1176,11 @@ def test_call_with_key_over_budget(prisma_client): start_time=datetime.now(), end_time=datetime.now(), 
) - await asyncio.sleep(10) + await update_spend( + prisma_client=prisma_client, + db_writer_client=None, + proxy_logging_obj=proxy_logging_obj, + ) # test spend_log was written and we can read it spend_logs = await view_spend_logs(request_id=request_id) @@ -1202,7 +1206,10 @@ def test_call_with_key_over_budget(prisma_client): except Exception as e: # print(f"Error - {str(e)}") traceback.print_exc() - error_detail = e.message + if hasattr(e, "message"): + error_detail = e.message + else: + error_detail = str(e) assert "Authentication Error, ExceededTokenBudget:" in error_detail print(vars(e)) @@ -1251,9 +1258,9 @@ def test_call_with_key_over_model_budget(prisma_client): from litellm.caching import Cache litellm.cache = Cache() - import time + import time, uuid - request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}" + request_id = f"chatcmpl-{uuid.uuid4()}" resp = ModelResponse( id=request_id, @@ -1286,7 +1293,11 @@ def test_call_with_key_over_model_budget(prisma_client): start_time=datetime.now(), end_time=datetime.now(), ) - await asyncio.sleep(10) + await update_spend( + prisma_client=prisma_client, + db_writer_client=None, + proxy_logging_obj=proxy_logging_obj, + ) # test spend_log was written and we can read it spend_logs = await view_spend_logs(request_id=request_id) @@ -1344,9 +1355,9 @@ async def test_call_with_key_never_over_budget(prisma_client): _PROXY_track_cost_callback as track_cost_callback, ) from litellm import ModelResponse, Choices, Message, Usage - import time + import time, uuid - request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}" + request_id = f"chatcmpl-{uuid.uuid4()}" resp = ModelResponse( id=request_id, @@ -1381,7 +1392,11 @@ async def test_call_with_key_never_over_budget(prisma_client): start_time=datetime.now(), end_time=datetime.now(), ) - await asyncio.sleep(5) + await update_spend( + prisma_client=prisma_client, + db_writer_client=None, + proxy_logging_obj=proxy_logging_obj, + ) # use 
generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) @@ -1421,9 +1436,9 @@ async def test_call_with_key_over_budget_stream(prisma_client): _PROXY_track_cost_callback as track_cost_callback, ) from litellm import ModelResponse, Choices, Message, Usage - import time + import time, uuid - request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}" + request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}" resp = ModelResponse( id=request_id, choices=[ @@ -1457,7 +1472,11 @@ async def test_call_with_key_over_budget_stream(prisma_client): start_time=datetime.now(), end_time=datetime.now(), ) - await asyncio.sleep(5) + await update_spend( + prisma_client=prisma_client, + db_writer_client=None, + proxy_logging_obj=proxy_logging_obj, + ) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) diff --git a/tests/test_spend_logs.py b/tests/test_spend_logs.py index c6866317d..477fdb86f 100644 --- a/tests/test_spend_logs.py +++ b/tests/test_spend_logs.py @@ -109,7 +109,7 @@ async def test_spend_logs(): key_gen = await generate_key(session=session) key = key_gen["key"] response = await chat_completion(session=session, key=key) - await asyncio.sleep(5) + await asyncio.sleep(20) await get_spend_logs(session=session, request_id=response["id"])