mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge pull request #2561 from BerriAI/litellm_batch_writing_db
fix(proxy/utils.py): move to batch writing db updates
This commit is contained in:
commit
c4dbd0407e
9 changed files with 352 additions and 101 deletions
|
@ -12,6 +12,7 @@ from typing import Any, Literal, Union, BinaryIO
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import dotenv, traceback, random, asyncio, time, contextvars
|
import dotenv, traceback, random, asyncio, time, contextvars
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import litellm
|
import litellm
|
||||||
from ._logging import verbose_logger
|
from ._logging import verbose_logger
|
||||||
|
|
|
@ -14,7 +14,8 @@ litellm_settings:
|
||||||
cache_params:
|
cache_params:
|
||||||
type: redis
|
type: redis
|
||||||
callbacks: ["batch_redis_requests"]
|
callbacks: ["batch_redis_requests"]
|
||||||
|
# success_callbacks: ["langfuse"]
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
# database_url: "postgresql://krrishdholakia:9yQkKWiB8vVs@ep-icy-union-a5j4dwls.us-east-2.aws.neon.tech/neondb?sslmode=require"
|
database_url: "postgresql://neondb_owner:hz8tyUlJ5ivV@ep-cool-sunset-a5ywubeh.us-east-2.aws.neon.tech/neondb?sslmode=require"
|
|
@ -96,6 +96,8 @@ from litellm.proxy.utils import (
|
||||||
_is_user_proxy_admin,
|
_is_user_proxy_admin,
|
||||||
_is_projected_spend_over_limit,
|
_is_projected_spend_over_limit,
|
||||||
_get_projected_spend_over_limit,
|
_get_projected_spend_over_limit,
|
||||||
|
update_spend,
|
||||||
|
monitor_spend_list,
|
||||||
)
|
)
|
||||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||||
|
@ -277,6 +279,7 @@ litellm_proxy_admin_name = "default_user_id"
|
||||||
ui_access_mode: Literal["admin", "all"] = "all"
|
ui_access_mode: Literal["admin", "all"] = "all"
|
||||||
proxy_budget_rescheduler_min_time = 597
|
proxy_budget_rescheduler_min_time = 597
|
||||||
proxy_budget_rescheduler_max_time = 605
|
proxy_budget_rescheduler_max_time = 605
|
||||||
|
proxy_batch_write_at = 60 # in seconds
|
||||||
litellm_master_key_hash = None
|
litellm_master_key_hash = None
|
||||||
### INITIALIZE GLOBAL LOGGING OBJECT ###
|
### INITIALIZE GLOBAL LOGGING OBJECT ###
|
||||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
||||||
|
@ -995,10 +998,8 @@ async def _PROXY_track_cost_callback(
|
||||||
)
|
)
|
||||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||||
user_id = proxy_server_request.get("body", {}).get("user", None)
|
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||||
user_id = user_id or kwargs["litellm_params"]["metadata"].get(
|
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
|
||||||
"user_api_key_user_id", None
|
|
||||||
)
|
|
||||||
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
||||||
if kwargs.get("response_cost", None) is not None:
|
if kwargs.get("response_cost", None) is not None:
|
||||||
response_cost = kwargs["response_cost"]
|
response_cost = kwargs["response_cost"]
|
||||||
|
@ -1012,9 +1013,6 @@ async def _PROXY_track_cost_callback(
|
||||||
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
|
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.info(
|
|
||||||
f"response_cost {response_cost}, for user_id {user_id}"
|
|
||||||
)
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}"
|
f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}"
|
||||||
)
|
)
|
||||||
|
@ -1024,6 +1022,7 @@ async def _PROXY_track_cost_callback(
|
||||||
token=user_api_key,
|
token=user_api_key,
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
end_user_id=end_user_id,
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
completion_response=completion_response,
|
completion_response=completion_response,
|
||||||
|
@ -1065,6 +1064,7 @@ async def update_database(
|
||||||
token,
|
token,
|
||||||
response_cost,
|
response_cost,
|
||||||
user_id=None,
|
user_id=None,
|
||||||
|
end_user_id=None,
|
||||||
team_id=None,
|
team_id=None,
|
||||||
kwargs=None,
|
kwargs=None,
|
||||||
completion_response=None,
|
completion_response=None,
|
||||||
|
@ -1075,6 +1075,10 @@ async def update_database(
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
|
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
|
||||||
)
|
)
|
||||||
|
if isinstance(token, str) and token.startswith("sk-"):
|
||||||
|
hashed_token = hash_token(token=token)
|
||||||
|
else:
|
||||||
|
hashed_token = token
|
||||||
|
|
||||||
### UPDATE USER SPEND ###
|
### UPDATE USER SPEND ###
|
||||||
async def _update_user_db():
|
async def _update_user_db():
|
||||||
|
@ -1083,11 +1087,6 @@ async def update_database(
|
||||||
- Update litellm-proxy-budget row (global proxy spend)
|
- Update litellm-proxy-budget row (global proxy spend)
|
||||||
"""
|
"""
|
||||||
## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db
|
## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db
|
||||||
end_user_id = None
|
|
||||||
if isinstance(token, str) and token.startswith("sk-"):
|
|
||||||
hashed_token = hash_token(token=token)
|
|
||||||
else:
|
|
||||||
hashed_token = token
|
|
||||||
existing_token_obj = await user_api_key_cache.async_get_cache(
|
existing_token_obj = await user_api_key_cache.async_get_cache(
|
||||||
key=hashed_token
|
key=hashed_token
|
||||||
)
|
)
|
||||||
|
@ -1096,54 +1095,25 @@ async def update_database(
|
||||||
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||||
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
|
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
|
||||||
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
|
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
|
||||||
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
|
||||||
end_user_id = user_id
|
|
||||||
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name]
|
|
||||||
data_list = []
|
data_list = []
|
||||||
try:
|
try:
|
||||||
if prisma_client is not None: # update
|
if prisma_client is not None: # update
|
||||||
user_ids = [user_id, litellm_proxy_budget_name]
|
user_ids = [user_id]
|
||||||
## do a group update for the user-id of the key + global proxy budget
|
if (
|
||||||
await prisma_client.db.litellm_usertable.update_many(
|
litellm.max_budget > 0
|
||||||
where={"user_id": {"in": user_ids}},
|
): # track global proxy budget, if user set max budget
|
||||||
data={"spend": {"increment": response_cost}},
|
user_ids.append(litellm_proxy_budget_name)
|
||||||
)
|
### KEY CHANGE ###
|
||||||
|
for _id in user_ids:
|
||||||
|
prisma_client.user_list_transactons[_id] = (
|
||||||
|
response_cost
|
||||||
|
+ prisma_client.user_list_transactons.get(_id, 0)
|
||||||
|
)
|
||||||
if end_user_id is not None:
|
if end_user_id is not None:
|
||||||
if existing_user_obj is None:
|
prisma_client.end_user_list_transactons[end_user_id] = (
|
||||||
# if user does not exist in LiteLLM_UserTable, create a new user
|
response_cost
|
||||||
existing_spend = 0
|
+ prisma_client.user_list_transactons.get(end_user_id, 0)
|
||||||
max_user_budget = None
|
|
||||||
if litellm.max_user_budget is not None:
|
|
||||||
max_user_budget = litellm.max_user_budget
|
|
||||||
existing_user_obj = LiteLLM_UserTable(
|
|
||||||
user_id=end_user_id,
|
|
||||||
spend=0,
|
|
||||||
max_budget=max_user_budget,
|
|
||||||
user_email=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
existing_user_obj.spend = (
|
|
||||||
existing_user_obj.spend + response_cost
|
|
||||||
)
|
|
||||||
|
|
||||||
user_object_json = {**existing_user_obj.json(exclude_none=True)}
|
|
||||||
|
|
||||||
user_object_json["model_max_budget"] = json.dumps(
|
|
||||||
user_object_json["model_max_budget"]
|
|
||||||
)
|
)
|
||||||
user_object_json["model_spend"] = json.dumps(
|
|
||||||
user_object_json["model_spend"]
|
|
||||||
)
|
|
||||||
|
|
||||||
await prisma_client.db.litellm_usertable.upsert(
|
|
||||||
where={"user_id": end_user_id},
|
|
||||||
data={
|
|
||||||
"create": user_object_json,
|
|
||||||
"update": {"spend": {"increment": response_cost}},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
elif custom_db_client is not None:
|
elif custom_db_client is not None:
|
||||||
for id in user_ids:
|
for id in user_ids:
|
||||||
if id is None:
|
if id is None:
|
||||||
|
@ -1205,6 +1175,7 @@ async def update_database(
|
||||||
value={"spend": new_spend},
|
value={"spend": new_spend},
|
||||||
table_name="user",
|
table_name="user",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
"\033[91m"
|
"\033[91m"
|
||||||
|
@ -1215,12 +1186,12 @@ async def update_database(
|
||||||
async def _update_key_db():
|
async def _update_key_db():
|
||||||
try:
|
try:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"adding spend to key db. Response cost: {response_cost}. Token: {token}."
|
f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}."
|
||||||
)
|
)
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
await prisma_client.db.litellm_verificationtoken.update(
|
prisma_client.key_list_transactons[hashed_token] = (
|
||||||
where={"token": token},
|
response_cost
|
||||||
data={"spend": {"increment": response_cost}},
|
+ prisma_client.key_list_transactons.get(hashed_token, 0)
|
||||||
)
|
)
|
||||||
elif custom_db_client is not None:
|
elif custom_db_client is not None:
|
||||||
# Fetch the existing cost for the given token
|
# Fetch the existing cost for the given token
|
||||||
|
@ -1257,7 +1228,6 @@ async def update_database(
|
||||||
async def _insert_spend_log_to_db():
|
async def _insert_spend_log_to_db():
|
||||||
try:
|
try:
|
||||||
# Helper to generate payload to log
|
# Helper to generate payload to log
|
||||||
verbose_proxy_logger.debug("inserting spend log to db")
|
|
||||||
payload = get_logging_payload(
|
payload = get_logging_payload(
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
response_obj=completion_response,
|
response_obj=completion_response,
|
||||||
|
@ -1268,16 +1238,13 @@ async def update_database(
|
||||||
payload["spend"] = response_cost
|
payload["spend"] = response_cost
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
await prisma_client.insert_data(data=payload, table_name="spend")
|
await prisma_client.insert_data(data=payload, table_name="spend")
|
||||||
elif custom_db_client is not None:
|
|
||||||
await custom_db_client.insert_data(payload, table_name="spend")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.debug(
|
||||||
f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
### UPDATE KEY SPEND ###
|
### UPDATE TEAM SPEND ###
|
||||||
async def _update_team_db():
|
async def _update_team_db():
|
||||||
try:
|
try:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -1289,9 +1256,9 @@ async def update_database(
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
await prisma_client.db.litellm_teamtable.update(
|
prisma_client.team_list_transactons[team_id] = (
|
||||||
where={"team_id": team_id},
|
response_cost
|
||||||
data={"spend": {"increment": response_cost}},
|
+ prisma_client.team_list_transactons.get(team_id, 0)
|
||||||
)
|
)
|
||||||
elif custom_db_client is not None:
|
elif custom_db_client is not None:
|
||||||
# Fetch the existing cost for the given token
|
# Fetch the existing cost for the given token
|
||||||
|
@ -1327,7 +1294,8 @@ async def update_database(
|
||||||
asyncio.create_task(_update_user_db())
|
asyncio.create_task(_update_user_db())
|
||||||
asyncio.create_task(_update_key_db())
|
asyncio.create_task(_update_key_db())
|
||||||
asyncio.create_task(_update_team_db())
|
asyncio.create_task(_update_team_db())
|
||||||
asyncio.create_task(_insert_spend_log_to_db())
|
# asyncio.create_task(_insert_spend_log_to_db())
|
||||||
|
await _insert_spend_log_to_db()
|
||||||
|
|
||||||
verbose_proxy_logger.debug("Runs spend update on all tables")
|
verbose_proxy_logger.debug("Runs spend update on all tables")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1646,7 +1614,7 @@ class ProxyConfig:
|
||||||
"""
|
"""
|
||||||
Load config values into proxy global state
|
Load config values into proxy global state
|
||||||
"""
|
"""
|
||||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash
|
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at
|
||||||
|
|
||||||
# Load existing config
|
# Load existing config
|
||||||
config = await self.get_config(config_file_path=config_file_path)
|
config = await self.get_config(config_file_path=config_file_path)
|
||||||
|
@ -2010,6 +1978,10 @@ class ProxyConfig:
|
||||||
proxy_budget_rescheduler_max_time = general_settings.get(
|
proxy_budget_rescheduler_max_time = general_settings.get(
|
||||||
"proxy_budget_rescheduler_max_time", proxy_budget_rescheduler_max_time
|
"proxy_budget_rescheduler_max_time", proxy_budget_rescheduler_max_time
|
||||||
)
|
)
|
||||||
|
## BATCH WRITER ##
|
||||||
|
proxy_batch_write_at = general_settings.get(
|
||||||
|
"proxy_batch_write_at", proxy_batch_write_at
|
||||||
|
)
|
||||||
### BACKGROUND HEALTH CHECKS ###
|
### BACKGROUND HEALTH CHECKS ###
|
||||||
# Enable background health checks
|
# Enable background health checks
|
||||||
use_background_health_checks = general_settings.get(
|
use_background_health_checks = general_settings.get(
|
||||||
|
@ -2238,7 +2210,6 @@ async def generate_key_helper_fn(
|
||||||
saved_token["expires"] = saved_token["expires"].isoformat()
|
saved_token["expires"] = saved_token["expires"].isoformat()
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
## CREATE USER (If necessary)
|
## CREATE USER (If necessary)
|
||||||
verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
|
|
||||||
if query_type == "insert_data":
|
if query_type == "insert_data":
|
||||||
user_row = await prisma_client.insert_data(
|
user_row = await prisma_client.insert_data(
|
||||||
data=user_data, table_name="user"
|
data=user_data, table_name="user"
|
||||||
|
@ -2576,7 +2547,6 @@ async def startup_event():
|
||||||
# add master key to db
|
# add master key to db
|
||||||
if os.getenv("PROXY_ADMIN_ID", None) is not None:
|
if os.getenv("PROXY_ADMIN_ID", None) is not None:
|
||||||
litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID")
|
litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID")
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
generate_key_helper_fn(
|
generate_key_helper_fn(
|
||||||
duration=None,
|
duration=None,
|
||||||
|
@ -2638,9 +2608,18 @@ async def startup_event():
|
||||||
interval = random.randint(
|
interval = random.randint(
|
||||||
proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time
|
proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time
|
||||||
) # random interval, so multiple workers avoid resetting budget at the same time
|
) # random interval, so multiple workers avoid resetting budget at the same time
|
||||||
|
batch_writing_interval = random.randint(
|
||||||
|
proxy_batch_write_at - 3, proxy_batch_write_at + 3
|
||||||
|
) # random interval, so multiple workers avoid batch writing at the same time
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
reset_budget, "interval", seconds=interval, args=[prisma_client]
|
reset_budget, "interval", seconds=interval, args=[prisma_client]
|
||||||
)
|
)
|
||||||
|
scheduler.add_job(
|
||||||
|
update_spend,
|
||||||
|
"interval",
|
||||||
|
seconds=batch_writing_interval,
|
||||||
|
args=[prisma_client],
|
||||||
|
)
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ from dotenv import load_dotenv
|
||||||
|
|
||||||
litellm_client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")
|
litellm_client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")
|
||||||
|
|
||||||
|
|
||||||
async def litellm_completion():
|
async def litellm_completion():
|
||||||
# Your existing code for litellm_completion goes here
|
# Your existing code for litellm_completion goes here
|
||||||
try:
|
try:
|
||||||
|
@ -18,6 +19,7 @@ async def litellm_completion():
|
||||||
"content": f"{text}. Who was alexander the great? {uuid.uuid4()}",
|
"content": f"{text}. Who was alexander the great? {uuid.uuid4()}",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
user="my-new-end-user-1",
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -29,9 +31,9 @@ async def litellm_completion():
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
for i in range(6):
|
for i in range(3):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
n = 20 # Number of concurrent tasks
|
n = 10 # Number of concurrent tasks
|
||||||
tasks = [litellm_completion() for _ in range(n)]
|
tasks = [litellm_completion() for _ in range(n)]
|
||||||
|
|
||||||
chat_completions = await asyncio.gather(*tasks)
|
chat_completions = await asyncio.gather(*tasks)
|
||||||
|
|
|
@ -7,6 +7,10 @@ from litellm.proxy._types import (
|
||||||
LiteLLM_VerificationToken,
|
LiteLLM_VerificationToken,
|
||||||
LiteLLM_VerificationTokenView,
|
LiteLLM_VerificationTokenView,
|
||||||
LiteLLM_SpendLogs,
|
LiteLLM_SpendLogs,
|
||||||
|
LiteLLM_UserTable,
|
||||||
|
LiteLLM_EndUserTable,
|
||||||
|
LiteLLM_TeamTable,
|
||||||
|
Member,
|
||||||
)
|
)
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy.hooks.parallel_request_limiter import (
|
from litellm.proxy.hooks.parallel_request_limiter import (
|
||||||
|
@ -472,6 +476,12 @@ def on_backoff(details):
|
||||||
|
|
||||||
|
|
||||||
class PrismaClient:
|
class PrismaClient:
|
||||||
|
user_list_transactons: dict = {}
|
||||||
|
end_user_list_transactons: dict = {}
|
||||||
|
key_list_transactons: dict = {}
|
||||||
|
team_list_transactons: dict = {}
|
||||||
|
spend_log_transactons: List = []
|
||||||
|
|
||||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||||
print_verbose(
|
print_verbose(
|
||||||
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
||||||
|
@ -1841,6 +1851,141 @@ async def reset_budget(prisma_client: PrismaClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_spend(
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Batch write updates to db.
|
||||||
|
|
||||||
|
Triggered every minute.
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
user_id_list: dict,
|
||||||
|
keys_list: list,
|
||||||
|
team_list: list,
|
||||||
|
spend_logs: list,
|
||||||
|
"""
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"ENTERS UPDATE SPEND - len(prisma_client.user_list_transactons.keys()): {len(prisma_client.user_list_transactons.keys())}"
|
||||||
|
)
|
||||||
|
n_retry_times = 3
|
||||||
|
### UPDATE USER TABLE ###
|
||||||
|
if len(prisma_client.user_list_transactons.keys()) > 0:
|
||||||
|
for i in range(n_retry_times + 1):
|
||||||
|
try:
|
||||||
|
async with prisma_client.db.tx(timeout=6000) as transaction:
|
||||||
|
async with transaction.batch_() as batcher:
|
||||||
|
for (
|
||||||
|
user_id,
|
||||||
|
response_cost,
|
||||||
|
) in prisma_client.user_list_transactons.items():
|
||||||
|
batcher.litellm_usertable.update_many(
|
||||||
|
where={"user_id": user_id},
|
||||||
|
data={"spend": {"increment": response_cost}},
|
||||||
|
)
|
||||||
|
prisma_client.user_list_transactons = (
|
||||||
|
{}
|
||||||
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
|
except httpx.ReadTimeout:
|
||||||
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
|
raise # Re-raise the last exception
|
||||||
|
# Optionally, sleep for a bit before retrying
|
||||||
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
### UPDATE END-USER TABLE ###
|
||||||
|
if len(prisma_client.end_user_list_transactons.keys()) > 0:
|
||||||
|
for i in range(n_retry_times + 1):
|
||||||
|
try:
|
||||||
|
async with prisma_client.db.tx(timeout=6000) as transaction:
|
||||||
|
async with transaction.batch_() as batcher:
|
||||||
|
for (
|
||||||
|
end_user_id,
|
||||||
|
response_cost,
|
||||||
|
) in prisma_client.end_user_list_transactons.items():
|
||||||
|
max_user_budget = None
|
||||||
|
if litellm.max_user_budget is not None:
|
||||||
|
max_user_budget = litellm.max_user_budget
|
||||||
|
new_user_obj = LiteLLM_EndUserTable(
|
||||||
|
user_id=end_user_id, spend=response_cost, blocked=False
|
||||||
|
)
|
||||||
|
batcher.litellm_endusertable.update_many(
|
||||||
|
where={"user_id": end_user_id},
|
||||||
|
data={"spend": {"increment": response_cost}},
|
||||||
|
)
|
||||||
|
prisma_client.end_user_list_transactons = (
|
||||||
|
{}
|
||||||
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
|
except httpx.ReadTimeout:
|
||||||
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
|
raise # Re-raise the last exception
|
||||||
|
# Optionally, sleep for a bit before retrying
|
||||||
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
### UPDATE KEY TABLE ###
|
||||||
|
if len(prisma_client.key_list_transactons.keys()) > 0:
|
||||||
|
for i in range(n_retry_times + 1):
|
||||||
|
try:
|
||||||
|
async with prisma_client.db.tx(timeout=6000) as transaction:
|
||||||
|
async with transaction.batch_() as batcher:
|
||||||
|
for (
|
||||||
|
token,
|
||||||
|
response_cost,
|
||||||
|
) in prisma_client.key_list_transactons.items():
|
||||||
|
batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||||
|
where={"token": token},
|
||||||
|
data={"spend": {"increment": response_cost}},
|
||||||
|
)
|
||||||
|
prisma_client.key_list_transactons = (
|
||||||
|
{}
|
||||||
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
|
except httpx.ReadTimeout:
|
||||||
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
|
raise # Re-raise the last exception
|
||||||
|
# Optionally, sleep for a bit before retrying
|
||||||
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
### UPDATE TEAM TABLE ###
|
||||||
|
if len(prisma_client.team_list_transactons.keys()) > 0:
|
||||||
|
for i in range(n_retry_times + 1):
|
||||||
|
try:
|
||||||
|
async with prisma_client.db.tx(timeout=6000) as transaction:
|
||||||
|
async with transaction.batch_() as batcher:
|
||||||
|
for (
|
||||||
|
team_id,
|
||||||
|
response_cost,
|
||||||
|
) in prisma_client.team_list_transactons.items():
|
||||||
|
batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||||
|
where={"team_id": team_id},
|
||||||
|
data={"spend": {"increment": response_cost}},
|
||||||
|
)
|
||||||
|
prisma_client.team_list_transactons = (
|
||||||
|
{}
|
||||||
|
) # Clear the remaining transactions after processing all batches in the loop.
|
||||||
|
except httpx.ReadTimeout:
|
||||||
|
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||||
|
raise # Re-raise the last exception
|
||||||
|
# Optionally, sleep for a bit before retrying
|
||||||
|
await asyncio.sleep(2**i) # Exponential backoff
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
### UPDATE SPEND LOGS TABLE ###
|
||||||
|
|
||||||
|
|
||||||
|
async def monitor_spend_list(prisma_client: PrismaClient):
|
||||||
|
"""
|
||||||
|
Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db
|
||||||
|
"""
|
||||||
|
if len(prisma_client.user_list_transactons) > 10000:
|
||||||
|
await update_spend(prisma_client=prisma_client)
|
||||||
|
|
||||||
|
|
||||||
async def _read_request_body(request):
|
async def _read_request_body(request):
|
||||||
"""
|
"""
|
||||||
Asynchronous function to read the request body and parse it as JSON or literal data.
|
Asynchronous function to read the request body and parse it as JSON or literal data.
|
||||||
|
|
|
@ -374,7 +374,8 @@ def test_gemini_pro_vision_base64():
|
||||||
print(resp)
|
print(resp)
|
||||||
|
|
||||||
prompt_tokens = resp.usage.prompt_tokens
|
prompt_tokens = resp.usage.prompt_tokens
|
||||||
|
except litellm.RateLimitError as e:
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "500 Internal error encountered.'" in str(e):
|
if "500 Internal error encountered.'" in str(e):
|
||||||
pass
|
pass
|
||||||
|
@ -457,33 +458,43 @@ def test_gemini_pro_function_calling_streaming():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gemini_pro_async_function_calling():
|
async def test_gemini_pro_async_function_calling():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
tools = [
|
try:
|
||||||
{
|
tools = [
|
||||||
"type": "function",
|
{
|
||||||
"function": {
|
"type": "function",
|
||||||
"name": "get_current_weather",
|
"function": {
|
||||||
"description": "Get the current weather in a given location",
|
"name": "get_current_weather",
|
||||||
"parameters": {
|
"description": "Get the current weather in a given location",
|
||||||
"type": "object",
|
"parameters": {
|
||||||
"properties": {
|
"type": "object",
|
||||||
"location": {
|
"properties": {
|
||||||
"type": "string",
|
"location": {
|
||||||
"description": "The city and state, e.g. San Francisco, CA",
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
"required": ["location"],
|
||||||
},
|
},
|
||||||
"required": ["location"],
|
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
}
|
]
|
||||||
]
|
messages = [
|
||||||
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
|
{"role": "user", "content": "What's the weather like in Boston today?"}
|
||||||
completion = await litellm.acompletion(
|
]
|
||||||
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
|
completion = await litellm.acompletion(
|
||||||
)
|
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
|
||||||
print(f"completion: {completion}")
|
)
|
||||||
assert completion.choices[0].message.content is None
|
print(f"completion: {completion}")
|
||||||
assert len(completion.choices[0].message.tool_calls) == 1
|
assert completion.choices[0].message.content is None
|
||||||
|
assert len(completion.choices[0].message.tool_calls) == 1
|
||||||
|
except litellm.RateLimitError as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
# raise Exception("it worked!")
|
# raise Exception("it worked!")
|
||||||
|
|
||||||
|
|
||||||
|
@ -499,6 +510,8 @@ def test_vertexai_embedding():
|
||||||
input=["good morning from litellm", "this is another item"],
|
input=["good morning from litellm", "this is another item"],
|
||||||
)
|
)
|
||||||
print(f"response:", response)
|
print(f"response:", response)
|
||||||
|
except litellm.RateLimitError as e:
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
@ -513,6 +526,8 @@ async def test_vertexai_aembedding():
|
||||||
input=["good morning from litellm", "this is another item"],
|
input=["good morning from litellm", "this is another item"],
|
||||||
)
|
)
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
|
except litellm.RateLimitError as e:
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
95
litellm/tests/test_update_spend.py
Normal file
95
litellm/tests/test_update_spend.py
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
# What is this?
|
||||||
|
## This tests the batch update spend logic on the proxy server
|
||||||
|
|
||||||
|
|
||||||
|
import sys, os, asyncio, time, random
|
||||||
|
from datetime import datetime
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest
|
||||||
|
import litellm
|
||||||
|
from litellm import Router, mock_completion
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
|
||||||
|
|
||||||
|
import pytest, logging, asyncio
|
||||||
|
import litellm, asyncio
|
||||||
|
from litellm.proxy.proxy_server import (
|
||||||
|
new_user,
|
||||||
|
generate_key_fn,
|
||||||
|
user_api_key_auth,
|
||||||
|
user_update,
|
||||||
|
delete_key_fn,
|
||||||
|
info_key_fn,
|
||||||
|
update_key_fn,
|
||||||
|
generate_key_fn,
|
||||||
|
generate_key_helper_fn,
|
||||||
|
spend_user_fn,
|
||||||
|
spend_key_fn,
|
||||||
|
view_spend_logs,
|
||||||
|
user_info,
|
||||||
|
block_user,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
|
||||||
|
verbose_proxy_logger.setLevel(level=logging.DEBUG)
|
||||||
|
|
||||||
|
from litellm.proxy._types import (
|
||||||
|
NewUserRequest,
|
||||||
|
GenerateKeyRequest,
|
||||||
|
DynamoDBArgs,
|
||||||
|
KeyRequest,
|
||||||
|
UpdateKeyRequest,
|
||||||
|
GenerateKeyRequest,
|
||||||
|
BlockUsers,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import DBClient
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prisma_client():
|
||||||
|
from litellm.proxy.proxy_cli import append_query_params
|
||||||
|
|
||||||
|
### add connection pool + pool timeout args
|
||||||
|
params = {"connection_limit": 100, "pool_timeout": 60}
|
||||||
|
database_url = os.getenv("DATABASE_URL")
|
||||||
|
modified_url = append_query_params(database_url, params)
|
||||||
|
os.environ["DATABASE_URL"] = modified_url
|
||||||
|
|
||||||
|
# Assuming DBClient is a class that needs to be instantiated
|
||||||
|
prisma_client = PrismaClient(
|
||||||
|
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reset litellm.proxy.proxy_server.prisma_client to None
|
||||||
|
litellm.proxy.proxy_server.custom_db_client = None
|
||||||
|
litellm.proxy.proxy_server.litellm_proxy_budget_name = (
|
||||||
|
f"litellm-proxy-budget-{time.time()}"
|
||||||
|
)
|
||||||
|
litellm.proxy.proxy_server.user_custom_key_generate = None
|
||||||
|
|
||||||
|
return prisma_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_update_spend(prisma_client):
|
||||||
|
prisma_client.user_list_transactons["test-litellm-user-5"] = 23
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
await update_spend(prisma_client=litellm.proxy.proxy_server.prisma_client)
|
|
@ -50,6 +50,7 @@ general_settings:
|
||||||
master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
|
master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
|
||||||
proxy_budget_rescheduler_min_time: 60
|
proxy_budget_rescheduler_min_time: 60
|
||||||
proxy_budget_rescheduler_max_time: 64
|
proxy_budget_rescheduler_max_time: 64
|
||||||
|
proxy_batch_write_at: 1
|
||||||
# database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
|
# database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
|
||||||
|
|
||||||
# environment_variables:
|
# environment_variables:
|
||||||
|
|
|
@ -329,6 +329,16 @@ async def test_key_info_spend_values():
|
||||||
- make completion call
|
- make completion call
|
||||||
- assert cost is expected value
|
- assert cost is expected value
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
async def retry_request(func, *args, _max_attempts=5, **kwargs):
|
||||||
|
for attempt in range(_max_attempts):
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except aiohttp.client_exceptions.ClientOSError as e:
|
||||||
|
if attempt + 1 == _max_attempts:
|
||||||
|
raise # re-raise the last ClientOSError if all attempts failed
|
||||||
|
print(f"Attempt {attempt+1} failed, retrying...")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
## Test Spend Update ##
|
## Test Spend Update ##
|
||||||
# completion
|
# completion
|
||||||
|
@ -336,7 +346,9 @@ async def test_key_info_spend_values():
|
||||||
key = key_gen["key"]
|
key = key_gen["key"]
|
||||||
response = await chat_completion(session=session, key=key)
|
response = await chat_completion(session=session, key=key)
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
spend_logs = await get_spend_logs(session=session, request_id=response["id"])
|
spend_logs = await retry_request(
|
||||||
|
get_spend_logs, session=session, request_id=response["id"]
|
||||||
|
)
|
||||||
print(f"spend_logs: {spend_logs}")
|
print(f"spend_logs: {spend_logs}")
|
||||||
completion_tokens = spend_logs[0]["completion_tokens"]
|
completion_tokens = spend_logs[0]["completion_tokens"]
|
||||||
prompt_tokens = spend_logs[0]["prompt_tokens"]
|
prompt_tokens = spend_logs[0]["prompt_tokens"]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue