perf(proxy_server.py): batch write spend logs

Reduces Prisma client errors by batch-writing spend logs to the database, flushing at most 1k logs at a time.
This commit is contained in:
Krrish Dholakia 2024-04-02 18:46:17 -07:00
parent c35b4c9b80
commit d7601a4844
5 changed files with 178 additions and 50 deletions

View file

@ -3,18 +3,18 @@ model_list:
litellm_params: litellm_params:
model: openai/my-fake-model model: openai/my-fake-model
api_key: my-fake-key api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/ api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
- model_name: gpt-instruct - model_name: gpt-instruct
litellm_params: litellm_params:
model: gpt-3.5-turbo-instruct model: gpt-3.5-turbo-instruct
# api_key: my-fake-key # api_key: my-fake-key
# api_base: https://exampleopenaiendpoint-production.up.railway.app/ # api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings: # litellm_settings:
drop_params: True # drop_params: True
max_budget: 800021 # max_budget: 800021
budget_duration: 30d # budget_duration: 30d
# cache: true # # cache: true
general_settings: general_settings:

View file

@ -1468,8 +1468,8 @@ async def update_database(
payload["spend"] = response_cost payload["spend"] = response_cost
if ( if (
os.getenv("SPEND_LOGS_URL", None) is not None prisma_client is not None
and prisma_client is not None and os.getenv("SPEND_LOGS_URL", None) is not None
): ):
if isinstance(payload["startTime"], datetime): if isinstance(payload["startTime"], datetime):
payload["startTime"] = payload["startTime"].isoformat() payload["startTime"] = payload["startTime"].isoformat()
@ -1477,7 +1477,7 @@ async def update_database(
payload["endTime"] = payload["endTime"].isoformat() payload["endTime"] = payload["endTime"].isoformat()
prisma_client.spend_log_transactons.append(payload) prisma_client.spend_log_transactons.append(payload)
elif prisma_client is not None: elif prisma_client is not None:
await prisma_client.insert_data(data=payload, table_name="spend") prisma_client.spend_log_transactions.append(payload)
except Exception as e: except Exception as e:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}" f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
@ -2966,7 +2966,7 @@ async def startup_event():
update_spend, update_spend,
"interval", "interval",
seconds=batch_writing_interval, seconds=batch_writing_interval,
args=[prisma_client, db_writer_client], args=[prisma_client, db_writer_client, proxy_logging_obj],
) )
scheduler.start() scheduler.start()

View file

@ -528,7 +528,7 @@ class PrismaClient:
end_user_list_transactons: dict = {} end_user_list_transactons: dict = {}
key_list_transactons: dict = {} key_list_transactons: dict = {}
team_list_transactons: dict = {} team_list_transactons: dict = {}
spend_log_transactons: List = [] spend_log_transactions: List = []
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print_verbose( print_verbose(
@ -1906,7 +1906,9 @@ async def reset_budget(prisma_client: PrismaClient):
async def update_spend( async def update_spend(
prisma_client: PrismaClient, db_writer_client: Optional[HTTPHandler] prisma_client: PrismaClient,
db_writer_client: Optional[HTTPHandler],
proxy_logging_obj: ProxyLogging,
): ):
""" """
Batch write updates to db. Batch write updates to db.
@ -1920,7 +1922,6 @@ async def update_spend(
spend_logs: list, spend_logs: list,
""" """
n_retry_times = 3 n_retry_times = 3
verbose_proxy_logger.debug("INSIDE UPDATE SPEND")
### UPDATE USER TABLE ### ### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0: if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): for i in range(n_retry_times + 1):
@ -1940,12 +1941,25 @@ async def update_spend(
prisma_client.user_list_transactons = ( prisma_client.user_list_transactons = (
{} {}
) # Clear the remaining transactions after processing all batches in the loop. ) # Clear the remaining transactions after processing all batches in the loop.
break
except httpx.ReadTimeout: except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying # Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff await asyncio.sleep(2**i) # Exponential backoff
except Exception as e: except Exception as e:
import traceback
error_msg = (
f"LiteLLM Prisma Client Exception - update user spend: {str(e)}"
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
)
)
raise e raise e
### UPDATE END-USER TABLE ### ### UPDATE END-USER TABLE ###
@ -1973,12 +1987,25 @@ async def update_spend(
prisma_client.end_user_list_transactons = ( prisma_client.end_user_list_transactons = (
{} {}
) # Clear the remaining transactions after processing all batches in the loop. ) # Clear the remaining transactions after processing all batches in the loop.
break
except httpx.ReadTimeout: except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying # Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff await asyncio.sleep(2**i) # Exponential backoff
except Exception as e: except Exception as e:
import traceback
error_msg = (
f"LiteLLM Prisma Client Exception - update end-user spend: {str(e)}"
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
)
)
raise e raise e
### UPDATE KEY TABLE ### ### UPDATE KEY TABLE ###
@ -2000,12 +2027,25 @@ async def update_spend(
prisma_client.key_list_transactons = ( prisma_client.key_list_transactons = (
{} {}
) # Clear the remaining transactions after processing all batches in the loop. ) # Clear the remaining transactions after processing all batches in the loop.
break
except httpx.ReadTimeout: except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying # Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff await asyncio.sleep(2**i) # Exponential backoff
except Exception as e: except Exception as e:
import traceback
error_msg = (
f"LiteLLM Prisma Client Exception - update key spend: {str(e)}"
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
)
)
raise e raise e
### UPDATE TEAM TABLE ### ### UPDATE TEAM TABLE ###
@ -2037,18 +2077,42 @@ async def update_spend(
prisma_client.team_list_transactons = ( prisma_client.team_list_transactons = (
{} {}
) # Clear the remaining transactions after processing all batches in the loop. ) # Clear the remaining transactions after processing all batches in the loop.
break
except httpx.ReadTimeout: except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying # Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff await asyncio.sleep(2**i) # Exponential backoff
except Exception as e: except Exception as e:
import traceback
error_msg = (
f"LiteLLM Prisma Client Exception - update team spend: {str(e)}"
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
)
)
raise e raise e
### UPDATE SPEND LOGS ### ### UPDATE SPEND LOGS ###
verbose_proxy_logger.debug(
"Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))
)
BATCH_SIZE = 100 # Preferred size of each batch to write to the database
MAX_LOGS_PER_INTERVAL = 1000 # Maximum number of logs to flush in a single interval
if len(prisma_client.spend_log_transactions) > 0:
for _ in range(n_retry_times + 1):
try:
base_url = os.getenv("SPEND_LOGS_URL", None) base_url = os.getenv("SPEND_LOGS_URL", None)
## WRITE TO SEPARATE SERVER ##
if ( if (
len(prisma_client.spend_log_transactons) > 0 len(prisma_client.spend_log_transactions) > 0
and base_url is not None and base_url is not None
and db_writer_client is not None and db_writer_client is not None
): ):
@ -2057,19 +2121,64 @@ async def update_spend(
verbose_proxy_logger.debug("base_url: {}".format(base_url)) verbose_proxy_logger.debug("base_url: {}".format(base_url))
response = await db_writer_client.post( response = await db_writer_client.post(
url=base_url + "spend/update", url=base_url + "spend/update",
data=json.dumps(prisma_client.spend_log_transactons), # type: ignore data=json.dumps(prisma_client.spend_log_transactions), # type: ignore
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
if response.status_code == 200: if response.status_code == 200:
prisma_client.spend_log_transactons = [] prisma_client.spend_log_transactions = []
else: ## (default) WRITE TO DB ##
logs_to_process = prisma_client.spend_log_transactions[
:MAX_LOGS_PER_INTERVAL
]
for i in range(0, len(logs_to_process), BATCH_SIZE):
# Create sublist for current batch, ensuring it doesn't exceed the BATCH_SIZE
batch = logs_to_process[i : i + BATCH_SIZE]
# Convert datetime strings to Date objects
batch_with_dates = [
prisma_client.jsonify_object(
{
**entry,
}
)
for entry in batch
]
# async def monitor_spend_list(prisma_client: PrismaClient): await prisma_client.db.litellm_spendlogs.create_many(
# """ data=batch_with_dates, skip_duplicates=True # type: ignore
# Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db )
# """
# if len(prisma_client.user_list_transactons) > 10000: verbose_proxy_logger.debug(
# await update_spend(prisma_client=prisma_client) f"Flushed {len(batch)} logs to the DB."
)
# Remove the processed logs from spend_logs
prisma_client.spend_log_transactions = (
prisma_client.spend_log_transactions[len(logs_to_process) :]
)
verbose_proxy_logger.debug(
f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
)
break
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
import traceback
error_msg = (
f"LiteLLM Prisma Client Exception - update spend logs: {str(e)}"
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
)
)
raise e
async def _read_request_body(request): async def _read_request_body(request):

View file

@ -51,7 +51,7 @@ from litellm.proxy.proxy_server import (
user_info, user_info,
info_key_fn, info_key_fn,
) )
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
verbose_proxy_logger.setLevel(level=logging.DEBUG) verbose_proxy_logger.setLevel(level=logging.DEBUG)
@ -1141,9 +1141,9 @@ def test_call_with_key_over_budget(prisma_client):
from litellm.caching import Cache from litellm.caching import Cache
litellm.cache = Cache() litellm.cache = Cache()
import time import time, uuid
request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}" request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}"
resp = ModelResponse( resp = ModelResponse(
id=request_id, id=request_id,
@ -1176,7 +1176,11 @@ def test_call_with_key_over_budget(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(10) await update_spend(
prisma_client=prisma_client,
db_writer_client=None,
proxy_logging_obj=proxy_logging_obj,
)
# test spend_log was written and we can read it # test spend_log was written and we can read it
spend_logs = await view_spend_logs(request_id=request_id) spend_logs = await view_spend_logs(request_id=request_id)
@ -1202,7 +1206,10 @@ def test_call_with_key_over_budget(prisma_client):
except Exception as e: except Exception as e:
# print(f"Error - {str(e)}") # print(f"Error - {str(e)}")
traceback.print_exc() traceback.print_exc()
if hasattr(e, "message"):
error_detail = e.message error_detail = e.message
else:
error_detail = str(e)
assert "Authentication Error, ExceededTokenBudget:" in error_detail assert "Authentication Error, ExceededTokenBudget:" in error_detail
print(vars(e)) print(vars(e))
@ -1251,9 +1258,9 @@ def test_call_with_key_over_model_budget(prisma_client):
from litellm.caching import Cache from litellm.caching import Cache
litellm.cache = Cache() litellm.cache = Cache()
import time import time, uuid
request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}" request_id = f"chatcmpl-{uuid.uuid4()}"
resp = ModelResponse( resp = ModelResponse(
id=request_id, id=request_id,
@ -1286,7 +1293,11 @@ def test_call_with_key_over_model_budget(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(10) await update_spend(
prisma_client=prisma_client,
db_writer_client=None,
proxy_logging_obj=proxy_logging_obj,
)
# test spend_log was written and we can read it # test spend_log was written and we can read it
spend_logs = await view_spend_logs(request_id=request_id) spend_logs = await view_spend_logs(request_id=request_id)
@ -1344,9 +1355,9 @@ async def test_call_with_key_never_over_budget(prisma_client):
_PROXY_track_cost_callback as track_cost_callback, _PROXY_track_cost_callback as track_cost_callback,
) )
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
import time import time, uuid
request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}" request_id = f"chatcmpl-{uuid.uuid4()}"
resp = ModelResponse( resp = ModelResponse(
id=request_id, id=request_id,
@ -1381,7 +1392,11 @@ async def test_call_with_key_never_over_budget(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(5) await update_spend(
prisma_client=prisma_client,
db_writer_client=None,
proxy_logging_obj=proxy_logging_obj,
)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
@ -1421,9 +1436,9 @@ async def test_call_with_key_over_budget_stream(prisma_client):
_PROXY_track_cost_callback as track_cost_callback, _PROXY_track_cost_callback as track_cost_callback,
) )
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
import time import time, uuid
request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}" request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}"
resp = ModelResponse( resp = ModelResponse(
id=request_id, id=request_id,
choices=[ choices=[
@ -1457,7 +1472,11 @@ async def test_call_with_key_over_budget_stream(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(5) await update_spend(
prisma_client=prisma_client,
db_writer_client=None,
proxy_logging_obj=proxy_logging_obj,
)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)

View file

@ -109,7 +109,7 @@ async def test_spend_logs():
key_gen = await generate_key(session=session) key_gen = await generate_key(session=session)
key = key_gen["key"] key = key_gen["key"]
response = await chat_completion(session=session, key=key) response = await chat_completion(session=session, key=key)
await asyncio.sleep(5) await asyncio.sleep(20)
await get_spend_logs(session=session, request_id=response["id"]) await get_spend_logs(session=session, request_id=response["id"])