forked from phoenix/litellm-mirror
fix(proxy/utils.py): move to batch writing db updates
This commit is contained in:
parent
710efab0de
commit
077b9c6234
4 changed files with 269 additions and 95 deletions
|
@ -14,7 +14,8 @@ litellm_settings:
|
||||||
cache_params:
|
cache_params:
|
||||||
type: redis
|
type: redis
|
||||||
callbacks: ["batch_redis_requests"]
|
callbacks: ["batch_redis_requests"]
|
||||||
|
# success_callbacks: ["langfuse"]
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
# database_url: "postgresql://krrishdholakia:9yQkKWiB8vVs@ep-icy-union-a5j4dwls.us-east-2.aws.neon.tech/neondb?sslmode=require"
|
database_url: "postgresql://neondb_owner:hz8tyUlJ5ivV@ep-cool-sunset-a5ywubeh.us-east-2.aws.neon.tech/neondb?sslmode=require"
|
|
@ -96,6 +96,8 @@ from litellm.proxy.utils import (
|
||||||
_is_user_proxy_admin,
|
_is_user_proxy_admin,
|
||||||
_is_projected_spend_over_limit,
|
_is_projected_spend_over_limit,
|
||||||
_get_projected_spend_over_limit,
|
_get_projected_spend_over_limit,
|
||||||
|
update_spend,
|
||||||
|
monitor_spend_list,
|
||||||
)
|
)
|
||||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||||
|
@ -1102,108 +1104,112 @@ async def update_database(
|
||||||
try:
|
try:
|
||||||
if prisma_client is not None: # update
|
if prisma_client is not None: # update
|
||||||
user_ids = [user_id, litellm_proxy_budget_name]
|
user_ids = [user_id, litellm_proxy_budget_name]
|
||||||
|
### KEY CHANGE ###
|
||||||
|
for _id in user_ids:
|
||||||
|
prisma_client.user_list_transactons.append((_id, response_cost))
|
||||||
|
######
|
||||||
## do a group update for the user-id of the key + global proxy budget
|
## do a group update for the user-id of the key + global proxy budget
|
||||||
await prisma_client.db.litellm_usertable.update_many(
|
# await prisma_client.db.litellm_usertable.update_many(
|
||||||
where={"user_id": {"in": user_ids}},
|
# where={"user_id": {"in": user_ids}},
|
||||||
data={"spend": {"increment": response_cost}},
|
# data={"spend": {"increment": response_cost}},
|
||||||
)
|
# )
|
||||||
if end_user_id is not None:
|
# if end_user_id is not None:
|
||||||
if existing_user_obj is None:
|
# if existing_user_obj is None:
|
||||||
# if user does not exist in LiteLLM_UserTable, create a new user
|
# # if user does not exist in LiteLLM_UserTable, create a new user
|
||||||
existing_spend = 0
|
# existing_spend = 0
|
||||||
max_user_budget = None
|
# max_user_budget = None
|
||||||
if litellm.max_user_budget is not None:
|
# if litellm.max_user_budget is not None:
|
||||||
max_user_budget = litellm.max_user_budget
|
# max_user_budget = litellm.max_user_budget
|
||||||
existing_user_obj = LiteLLM_UserTable(
|
# existing_user_obj = LiteLLM_UserTable(
|
||||||
user_id=end_user_id,
|
# user_id=end_user_id,
|
||||||
spend=0,
|
# spend=0,
|
||||||
max_budget=max_user_budget,
|
# max_budget=max_user_budget,
|
||||||
user_email=None,
|
# user_email=None,
|
||||||
)
|
# )
|
||||||
|
|
||||||
else:
|
# else:
|
||||||
existing_user_obj.spend = (
|
# existing_user_obj.spend = (
|
||||||
existing_user_obj.spend + response_cost
|
# existing_user_obj.spend + response_cost
|
||||||
)
|
# )
|
||||||
|
|
||||||
user_object_json = {**existing_user_obj.json(exclude_none=True)}
|
# user_object_json = {**existing_user_obj.json(exclude_none=True)}
|
||||||
|
|
||||||
user_object_json["model_max_budget"] = json.dumps(
|
# user_object_json["model_max_budget"] = json.dumps(
|
||||||
user_object_json["model_max_budget"]
|
# user_object_json["model_max_budget"]
|
||||||
)
|
# )
|
||||||
user_object_json["model_spend"] = json.dumps(
|
# user_object_json["model_spend"] = json.dumps(
|
||||||
user_object_json["model_spend"]
|
# user_object_json["model_spend"]
|
||||||
)
|
# )
|
||||||
|
|
||||||
await prisma_client.db.litellm_usertable.upsert(
|
# await prisma_client.db.litellm_usertable.upsert(
|
||||||
where={"user_id": end_user_id},
|
# where={"user_id": end_user_id},
|
||||||
data={
|
# data={
|
||||||
"create": user_object_json,
|
# "create": user_object_json,
|
||||||
"update": {"spend": {"increment": response_cost}},
|
# "update": {"spend": {"increment": response_cost}},
|
||||||
},
|
# },
|
||||||
)
|
# )
|
||||||
|
|
||||||
elif custom_db_client is not None:
|
# elif custom_db_client is not None:
|
||||||
for id in user_ids:
|
# for id in user_ids:
|
||||||
if id is None:
|
# if id is None:
|
||||||
continue
|
# continue
|
||||||
if (
|
# if (
|
||||||
custom_db_client is not None
|
# custom_db_client is not None
|
||||||
and id != litellm_proxy_budget_name
|
# and id != litellm_proxy_budget_name
|
||||||
):
|
# ):
|
||||||
existing_spend_obj = await custom_db_client.get_data(
|
# existing_spend_obj = await custom_db_client.get_data(
|
||||||
key=id, table_name="user"
|
# key=id, table_name="user"
|
||||||
)
|
# )
|
||||||
verbose_proxy_logger.debug(
|
# verbose_proxy_logger.debug(
|
||||||
f"Updating existing_spend_obj: {existing_spend_obj}"
|
# f"Updating existing_spend_obj: {existing_spend_obj}"
|
||||||
)
|
# )
|
||||||
if existing_spend_obj is None:
|
# if existing_spend_obj is None:
|
||||||
# if user does not exist in LiteLLM_UserTable, create a new user
|
# # if user does not exist in LiteLLM_UserTable, create a new user
|
||||||
existing_spend = 0
|
# existing_spend = 0
|
||||||
max_user_budget = None
|
# max_user_budget = None
|
||||||
if litellm.max_user_budget is not None:
|
# if litellm.max_user_budget is not None:
|
||||||
max_user_budget = litellm.max_user_budget
|
# max_user_budget = litellm.max_user_budget
|
||||||
existing_spend_obj = LiteLLM_UserTable(
|
# existing_spend_obj = LiteLLM_UserTable(
|
||||||
user_id=id,
|
# user_id=id,
|
||||||
spend=0,
|
# spend=0,
|
||||||
max_budget=max_user_budget,
|
# max_budget=max_user_budget,
|
||||||
user_email=None,
|
# user_email=None,
|
||||||
)
|
# )
|
||||||
else:
|
# else:
|
||||||
existing_spend = existing_spend_obj.spend
|
# existing_spend = existing_spend_obj.spend
|
||||||
|
|
||||||
# Calculate the new cost by adding the existing cost and response_cost
|
# # Calculate the new cost by adding the existing cost and response_cost
|
||||||
existing_spend_obj.spend = existing_spend + response_cost
|
# existing_spend_obj.spend = existing_spend + response_cost
|
||||||
|
|
||||||
# track cost per model, for the given user
|
# # track cost per model, for the given user
|
||||||
spend_per_model = existing_spend_obj.model_spend or {}
|
# spend_per_model = existing_spend_obj.model_spend or {}
|
||||||
current_model = kwargs.get("model")
|
# current_model = kwargs.get("model")
|
||||||
|
|
||||||
if current_model is not None and spend_per_model is not None:
|
# if current_model is not None and spend_per_model is not None:
|
||||||
if spend_per_model.get(current_model) is None:
|
# if spend_per_model.get(current_model) is None:
|
||||||
spend_per_model[current_model] = response_cost
|
# spend_per_model[current_model] = response_cost
|
||||||
else:
|
# else:
|
||||||
spend_per_model[current_model] += response_cost
|
# spend_per_model[current_model] += response_cost
|
||||||
existing_spend_obj.model_spend = spend_per_model
|
# existing_spend_obj.model_spend = spend_per_model
|
||||||
|
|
||||||
valid_token = user_api_key_cache.get_cache(key=id)
|
# valid_token = user_api_key_cache.get_cache(key=id)
|
||||||
if valid_token is not None and isinstance(valid_token, dict):
|
# if valid_token is not None and isinstance(valid_token, dict):
|
||||||
user_api_key_cache.set_cache(
|
# user_api_key_cache.set_cache(
|
||||||
key=id, value=existing_spend_obj.json()
|
# key=id, value=existing_spend_obj.json()
|
||||||
)
|
# )
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
# verbose_proxy_logger.debug(
|
||||||
f"user - new cost: {existing_spend_obj.spend}, user_id: {id}"
|
# f"user - new cost: {existing_spend_obj.spend}, user_id: {id}"
|
||||||
)
|
# )
|
||||||
data_list.append(existing_spend_obj)
|
# data_list.append(existing_spend_obj)
|
||||||
|
|
||||||
if custom_db_client is not None and user_id is not None:
|
# if custom_db_client is not None and user_id is not None:
|
||||||
new_spend = data_list[0].spend
|
# new_spend = data_list[0].spend
|
||||||
await custom_db_client.update_data(
|
# await custom_db_client.update_data(
|
||||||
key=user_id,
|
# key=user_id,
|
||||||
value={"spend": new_spend},
|
# value={"spend": new_spend},
|
||||||
table_name="user",
|
# table_name="user",
|
||||||
)
|
# )
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
"\033[91m"
|
"\033[91m"
|
||||||
|
@ -1323,10 +1329,12 @@ async def update_database(
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
asyncio.create_task(_update_user_db())
|
# asyncio.create_task(_update_user_db())
|
||||||
asyncio.create_task(_update_key_db())
|
await _update_user_db()
|
||||||
asyncio.create_task(_update_team_db())
|
|
||||||
asyncio.create_task(_insert_spend_log_to_db())
|
# asyncio.create_task(_update_key_db())
|
||||||
|
# asyncio.create_task(_update_team_db())
|
||||||
|
# asyncio.create_task(_insert_spend_log_to_db())
|
||||||
|
|
||||||
verbose_proxy_logger.debug("Runs spend update on all tables")
|
verbose_proxy_logger.debug("Runs spend update on all tables")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -2632,6 +2640,10 @@ async def startup_event():
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
reset_budget, "interval", seconds=interval, args=[prisma_client]
|
reset_budget, "interval", seconds=interval, args=[prisma_client]
|
||||||
)
|
)
|
||||||
|
# scheduler.add_job(
|
||||||
|
# monitor_spend_list, "interval", seconds=10, args=[prisma_client]
|
||||||
|
# )
|
||||||
|
scheduler.add_job(update_spend, "interval", seconds=60, args=[prisma_client])
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -472,6 +472,11 @@ def on_backoff(details):
|
||||||
|
|
||||||
|
|
||||||
class PrismaClient:
|
class PrismaClient:
|
||||||
|
user_list_transactons: List = []
|
||||||
|
key_list_transactons: List = []
|
||||||
|
team_list_transactons: List = []
|
||||||
|
spend_log_transactons: List = []
|
||||||
|
|
||||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||||
print_verbose(
|
print_verbose(
|
||||||
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
||||||
|
@ -1841,6 +1846,65 @@ async def reset_budget(prisma_client: PrismaClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_spend(
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Batch write updates to db.
|
||||||
|
|
||||||
|
Triggered every minute.
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
user_id_list: list,
|
||||||
|
keys_list: list,
|
||||||
|
team_list: list,
|
||||||
|
spend_logs: list,
|
||||||
|
"""
|
||||||
|
n_retry_times = 3
|
||||||
|
### UPDATE USER TABLE ###
|
||||||
|
if len(prisma_client.user_list_transactons) > 0:
|
||||||
|
for i in range(n_retry_times + 1):
|
||||||
|
try:
|
||||||
|
remaining_transactions = list(prisma_client.user_list_transactons)
|
||||||
|
while remaining_transactions:
|
||||||
|
batch_size = min(5000, len(remaining_transactions))
|
||||||
|
batch_transactions = remaining_transactions[:batch_size]
|
||||||
|
async with prisma_client.db.tx(timeout=60000) as transaction:
|
||||||
|
async with transaction.batch_() as batcher:
|
||||||
|
for user_id_tuple in batch_transactions:
|
||||||
|
user_id, response_cost = user_id_tuple
|
||||||
|
if user_id != "litellm-proxy-budget":
|
||||||
|
batcher.litellm_usertable.update(
|
||||||
|
where={"user_id": user_id},
|
||||||
|
data={"spend": {"increment": response_cost}},
|
||||||
|
)
|
||||||
|
|
||||||
|
remaining_transactions = remaining_transactions[batch_size:]
|
||||||
|
|
||||||
|
prisma_client.user_list_transactons = (
|
||||||
|
[]
|
||||||
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
|
except httpx.ReadTimeout:
|
||||||
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
|
raise # Re-raise the last exception
|
||||||
|
# Optionally, sleep for a bit before retrying
|
||||||
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
### UPDATE KEY TABLE ###
|
||||||
|
### UPDATE TEAM TABLE ###
|
||||||
|
### UPDATE SPEND LOGS TABLE ###
|
||||||
|
|
||||||
|
|
||||||
|
async def monitor_spend_list(prisma_client: PrismaClient):
|
||||||
|
"""
|
||||||
|
Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db
|
||||||
|
"""
|
||||||
|
if len(prisma_client.user_list_transactons) > 10000:
|
||||||
|
await update_spend(prisma_client=prisma_client)
|
||||||
|
|
||||||
|
|
||||||
async def _read_request_body(request):
|
async def _read_request_body(request):
|
||||||
"""
|
"""
|
||||||
Asynchronous function to read the request body and parse it as JSON or literal data.
|
Asynchronous function to read the request body and parse it as JSON or literal data.
|
||||||
|
|
97
litellm/tests/test_update_spend.py
Normal file
97
litellm/tests/test_update_spend.py
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
# What is this?
|
||||||
|
## This tests the batch update spend logic on the proxy server
|
||||||
|
|
||||||
|
|
||||||
|
import sys, os, asyncio, time, random
|
||||||
|
from datetime import datetime
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest
|
||||||
|
import litellm
|
||||||
|
from litellm import Router, mock_completion
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
|
||||||
|
|
||||||
|
import pytest, logging, asyncio
|
||||||
|
import litellm, asyncio
|
||||||
|
from litellm.proxy.proxy_server import (
|
||||||
|
new_user,
|
||||||
|
generate_key_fn,
|
||||||
|
user_api_key_auth,
|
||||||
|
user_update,
|
||||||
|
delete_key_fn,
|
||||||
|
info_key_fn,
|
||||||
|
update_key_fn,
|
||||||
|
generate_key_fn,
|
||||||
|
generate_key_helper_fn,
|
||||||
|
spend_user_fn,
|
||||||
|
spend_key_fn,
|
||||||
|
view_spend_logs,
|
||||||
|
user_info,
|
||||||
|
block_user,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
|
||||||
|
verbose_proxy_logger.setLevel(level=logging.DEBUG)
|
||||||
|
|
||||||
|
from litellm.proxy._types import (
|
||||||
|
NewUserRequest,
|
||||||
|
GenerateKeyRequest,
|
||||||
|
DynamoDBArgs,
|
||||||
|
KeyRequest,
|
||||||
|
UpdateKeyRequest,
|
||||||
|
GenerateKeyRequest,
|
||||||
|
BlockUsers,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import DBClient
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prisma_client():
|
||||||
|
from litellm.proxy.proxy_cli import append_query_params
|
||||||
|
|
||||||
|
### add connection pool + pool timeout args
|
||||||
|
params = {"connection_limit": 100, "pool_timeout": 60}
|
||||||
|
database_url = os.getenv("DATABASE_URL")
|
||||||
|
modified_url = append_query_params(database_url, params)
|
||||||
|
os.environ["DATABASE_URL"] = modified_url
|
||||||
|
|
||||||
|
# Assuming DBClient is a class that needs to be instantiated
|
||||||
|
prisma_client = PrismaClient(
|
||||||
|
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reset litellm.proxy.proxy_server.prisma_client to None
|
||||||
|
litellm.proxy.proxy_server.custom_db_client = None
|
||||||
|
litellm.proxy.proxy_server.litellm_proxy_budget_name = (
|
||||||
|
f"litellm-proxy-budget-{time.time()}"
|
||||||
|
)
|
||||||
|
litellm.proxy.proxy_server.user_custom_key_generate = None
|
||||||
|
|
||||||
|
return prisma_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_update_spend(prisma_client):
|
||||||
|
prisma_client.user_list_transactons.append(("test-litellm-user-5", 23))
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
await update_spend(prisma_client=litellm.proxy.proxy_server.prisma_client)
|
||||||
|
|
||||||
|
raise Exception("it worked!")
|
Loading…
Add table
Add a link
Reference in a new issue