Merge pull request #2561 from BerriAI/litellm_batch_writing_db

fix(proxy/utils.py): move to batch writing db updates
This commit is contained in:
Krish Dholakia 2024-03-18 21:50:47 -07:00 committed by GitHub
commit c4dbd0407e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 352 additions and 101 deletions

View file

@ -12,6 +12,7 @@ from typing import Any, Literal, Union, BinaryIO
from functools import partial from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy from copy import deepcopy
import httpx import httpx
import litellm import litellm
from ._logging import verbose_logger from ._logging import verbose_logger

View file

@ -14,7 +14,8 @@ litellm_settings:
cache_params: cache_params:
type: redis type: redis
callbacks: ["batch_redis_requests"] callbacks: ["batch_redis_requests"]
# success_callbacks: ["langfuse"]
general_settings: general_settings:
master_key: sk-1234 master_key: sk-1234
# database_url: "postgresql://krrishdholakia:9yQkKWiB8vVs@ep-icy-union-a5j4dwls.us-east-2.aws.neon.tech/neondb?sslmode=require" database_url: "postgresql://neondb_owner:hz8tyUlJ5ivV@ep-cool-sunset-a5ywubeh.us-east-2.aws.neon.tech/neondb?sslmode=require"

View file

@ -96,6 +96,8 @@ from litellm.proxy.utils import (
_is_user_proxy_admin, _is_user_proxy_admin,
_is_projected_spend_over_limit, _is_projected_spend_over_limit,
_get_projected_spend_over_limit, _get_projected_spend_over_limit,
update_spend,
monitor_spend_list,
) )
from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
@ -277,6 +279,7 @@ litellm_proxy_admin_name = "default_user_id"
ui_access_mode: Literal["admin", "all"] = "all" ui_access_mode: Literal["admin", "all"] = "all"
proxy_budget_rescheduler_min_time = 597 proxy_budget_rescheduler_min_time = 597
proxy_budget_rescheduler_max_time = 605 proxy_budget_rescheduler_max_time = 605
proxy_batch_write_at = 60 # in seconds
litellm_master_key_hash = None litellm_master_key_hash = None
### INITIALIZE GLOBAL LOGGING OBJECT ### ### INITIALIZE GLOBAL LOGGING OBJECT ###
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
@ -995,10 +998,8 @@ async def _PROXY_track_cost_callback(
) )
litellm_params = kwargs.get("litellm_params", {}) or {} litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {} proxy_server_request = litellm_params.get("proxy_server_request") or {}
user_id = proxy_server_request.get("body", {}).get("user", None) end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = user_id or kwargs["litellm_params"]["metadata"].get( user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
"user_api_key_user_id", None
)
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None) team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
if kwargs.get("response_cost", None) is not None: if kwargs.get("response_cost", None) is not None:
response_cost = kwargs["response_cost"] response_cost = kwargs["response_cost"]
@ -1012,9 +1013,6 @@ async def _PROXY_track_cost_callback(
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}" f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
) )
verbose_proxy_logger.info(
f"response_cost {response_cost}, for user_id {user_id}"
)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}" f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}"
) )
@ -1024,6 +1022,7 @@ async def _PROXY_track_cost_callback(
token=user_api_key, token=user_api_key,
response_cost=response_cost, response_cost=response_cost,
user_id=user_id, user_id=user_id,
end_user_id=end_user_id,
team_id=team_id, team_id=team_id,
kwargs=kwargs, kwargs=kwargs,
completion_response=completion_response, completion_response=completion_response,
@ -1065,6 +1064,7 @@ async def update_database(
token, token,
response_cost, response_cost,
user_id=None, user_id=None,
end_user_id=None,
team_id=None, team_id=None,
kwargs=None, kwargs=None,
completion_response=None, completion_response=None,
@ -1075,6 +1075,10 @@ async def update_database(
verbose_proxy_logger.info( verbose_proxy_logger.info(
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}" f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
) )
if isinstance(token, str) and token.startswith("sk-"):
hashed_token = hash_token(token=token)
else:
hashed_token = token
### UPDATE USER SPEND ### ### UPDATE USER SPEND ###
async def _update_user_db(): async def _update_user_db():
@ -1083,11 +1087,6 @@ async def update_database(
- Update litellm-proxy-budget row (global proxy spend) - Update litellm-proxy-budget row (global proxy spend)
""" """
## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db ## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db
end_user_id = None
if isinstance(token, str) and token.startswith("sk-"):
hashed_token = hash_token(token=token)
else:
hashed_token = token
existing_token_obj = await user_api_key_cache.async_get_cache( existing_token_obj = await user_api_key_cache.async_get_cache(
key=hashed_token key=hashed_token
) )
@ -1096,54 +1095,25 @@ async def update_database(
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
if existing_user_obj is not None and isinstance(existing_user_obj, dict): if existing_user_obj is not None and isinstance(existing_user_obj, dict):
existing_user_obj = LiteLLM_UserTable(**existing_user_obj) existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
if existing_token_obj.user_id != user_id: # an end-user id was passed in
end_user_id = user_id
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name]
data_list = [] data_list = []
try: try:
if prisma_client is not None: # update if prisma_client is not None: # update
user_ids = [user_id, litellm_proxy_budget_name] user_ids = [user_id]
## do a group update for the user-id of the key + global proxy budget if (
await prisma_client.db.litellm_usertable.update_many( litellm.max_budget > 0
where={"user_id": {"in": user_ids}}, ): # track global proxy budget, if user set max budget
data={"spend": {"increment": response_cost}}, user_ids.append(litellm_proxy_budget_name)
### KEY CHANGE ###
for _id in user_ids:
prisma_client.user_list_transactons[_id] = (
response_cost
+ prisma_client.user_list_transactons.get(_id, 0)
) )
if end_user_id is not None: if end_user_id is not None:
if existing_user_obj is None: prisma_client.end_user_list_transactons[end_user_id] = (
# if user does not exist in LiteLLM_UserTable, create a new user response_cost
existing_spend = 0 + prisma_client.user_list_transactons.get(end_user_id, 0)
max_user_budget = None
if litellm.max_user_budget is not None:
max_user_budget = litellm.max_user_budget
existing_user_obj = LiteLLM_UserTable(
user_id=end_user_id,
spend=0,
max_budget=max_user_budget,
user_email=None,
) )
else:
existing_user_obj.spend = (
existing_user_obj.spend + response_cost
)
user_object_json = {**existing_user_obj.json(exclude_none=True)}
user_object_json["model_max_budget"] = json.dumps(
user_object_json["model_max_budget"]
)
user_object_json["model_spend"] = json.dumps(
user_object_json["model_spend"]
)
await prisma_client.db.litellm_usertable.upsert(
where={"user_id": end_user_id},
data={
"create": user_object_json,
"update": {"spend": {"increment": response_cost}},
},
)
elif custom_db_client is not None: elif custom_db_client is not None:
for id in user_ids: for id in user_ids:
if id is None: if id is None:
@ -1205,6 +1175,7 @@ async def update_database(
value={"spend": new_spend}, value={"spend": new_spend},
table_name="user", table_name="user",
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.info( verbose_proxy_logger.info(
"\033[91m" "\033[91m"
@ -1215,12 +1186,12 @@ async def update_database(
async def _update_key_db(): async def _update_key_db():
try: try:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"adding spend to key db. Response cost: {response_cost}. Token: {token}." f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}."
) )
if prisma_client is not None: if prisma_client is not None:
await prisma_client.db.litellm_verificationtoken.update( prisma_client.key_list_transactons[hashed_token] = (
where={"token": token}, response_cost
data={"spend": {"increment": response_cost}}, + prisma_client.key_list_transactons.get(hashed_token, 0)
) )
elif custom_db_client is not None: elif custom_db_client is not None:
# Fetch the existing cost for the given token # Fetch the existing cost for the given token
@ -1257,7 +1228,6 @@ async def update_database(
async def _insert_spend_log_to_db(): async def _insert_spend_log_to_db():
try: try:
# Helper to generate payload to log # Helper to generate payload to log
verbose_proxy_logger.debug("inserting spend log to db")
payload = get_logging_payload( payload = get_logging_payload(
kwargs=kwargs, kwargs=kwargs,
response_obj=completion_response, response_obj=completion_response,
@ -1268,16 +1238,13 @@ async def update_database(
payload["spend"] = response_cost payload["spend"] = response_cost
if prisma_client is not None: if prisma_client is not None:
await prisma_client.insert_data(data=payload, table_name="spend") await prisma_client.insert_data(data=payload, table_name="spend")
elif custom_db_client is not None:
await custom_db_client.insert_data(payload, table_name="spend")
except Exception as e: except Exception as e:
verbose_proxy_logger.info( verbose_proxy_logger.debug(
f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}" f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
) )
raise e raise e
### UPDATE KEY SPEND ### ### UPDATE TEAM SPEND ###
async def _update_team_db(): async def _update_team_db():
try: try:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
@ -1289,9 +1256,9 @@ async def update_database(
) )
return return
if prisma_client is not None: if prisma_client is not None:
await prisma_client.db.litellm_teamtable.update( prisma_client.team_list_transactons[team_id] = (
where={"team_id": team_id}, response_cost
data={"spend": {"increment": response_cost}}, + prisma_client.team_list_transactons.get(team_id, 0)
) )
elif custom_db_client is not None: elif custom_db_client is not None:
# Fetch the existing cost for the given token # Fetch the existing cost for the given token
@ -1327,7 +1294,8 @@ async def update_database(
asyncio.create_task(_update_user_db()) asyncio.create_task(_update_user_db())
asyncio.create_task(_update_key_db()) asyncio.create_task(_update_key_db())
asyncio.create_task(_update_team_db()) asyncio.create_task(_update_team_db())
asyncio.create_task(_insert_spend_log_to_db()) # asyncio.create_task(_insert_spend_log_to_db())
await _insert_spend_log_to_db()
verbose_proxy_logger.debug("Runs spend update on all tables") verbose_proxy_logger.debug("Runs spend update on all tables")
except Exception as e: except Exception as e:
@ -1646,7 +1614,7 @@ class ProxyConfig:
""" """
Load config values into proxy global state Load config values into proxy global state
""" """
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at
# Load existing config # Load existing config
config = await self.get_config(config_file_path=config_file_path) config = await self.get_config(config_file_path=config_file_path)
@ -2010,6 +1978,10 @@ class ProxyConfig:
proxy_budget_rescheduler_max_time = general_settings.get( proxy_budget_rescheduler_max_time = general_settings.get(
"proxy_budget_rescheduler_max_time", proxy_budget_rescheduler_max_time "proxy_budget_rescheduler_max_time", proxy_budget_rescheduler_max_time
) )
## BATCH WRITER ##
proxy_batch_write_at = general_settings.get(
"proxy_batch_write_at", proxy_batch_write_at
)
### BACKGROUND HEALTH CHECKS ### ### BACKGROUND HEALTH CHECKS ###
# Enable background health checks # Enable background health checks
use_background_health_checks = general_settings.get( use_background_health_checks = general_settings.get(
@ -2238,7 +2210,6 @@ async def generate_key_helper_fn(
saved_token["expires"] = saved_token["expires"].isoformat() saved_token["expires"] = saved_token["expires"].isoformat()
if prisma_client is not None: if prisma_client is not None:
## CREATE USER (If necessary) ## CREATE USER (If necessary)
verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
if query_type == "insert_data": if query_type == "insert_data":
user_row = await prisma_client.insert_data( user_row = await prisma_client.insert_data(
data=user_data, table_name="user" data=user_data, table_name="user"
@ -2576,7 +2547,6 @@ async def startup_event():
# add master key to db # add master key to db
if os.getenv("PROXY_ADMIN_ID", None) is not None: if os.getenv("PROXY_ADMIN_ID", None) is not None:
litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID") litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID")
asyncio.create_task( asyncio.create_task(
generate_key_helper_fn( generate_key_helper_fn(
duration=None, duration=None,
@ -2638,9 +2608,18 @@ async def startup_event():
interval = random.randint( interval = random.randint(
proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time
) # random interval, so multiple workers avoid resetting budget at the same time ) # random interval, so multiple workers avoid resetting budget at the same time
batch_writing_interval = random.randint(
proxy_batch_write_at - 3, proxy_batch_write_at + 3
) # random interval, so multiple workers avoid batch writing at the same time
scheduler.add_job( scheduler.add_job(
reset_budget, "interval", seconds=interval, args=[prisma_client] reset_budget, "interval", seconds=interval, args=[prisma_client]
) )
scheduler.add_job(
update_spend,
"interval",
seconds=batch_writing_interval,
args=[prisma_client],
)
scheduler.start() scheduler.start()

View file

@ -7,6 +7,7 @@ from dotenv import load_dotenv
litellm_client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234") litellm_client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")
async def litellm_completion(): async def litellm_completion():
# Your existing code for litellm_completion goes here # Your existing code for litellm_completion goes here
try: try:
@ -18,6 +19,7 @@ async def litellm_completion():
"content": f"{text}. Who was alexander the great? {uuid.uuid4()}", "content": f"{text}. Who was alexander the great? {uuid.uuid4()}",
} }
], ],
user="my-new-end-user-1",
) )
return response return response
@ -29,9 +31,9 @@ async def litellm_completion():
async def main(): async def main():
for i in range(6): for i in range(3):
start = time.time() start = time.time()
n = 20 # Number of concurrent tasks n = 10 # Number of concurrent tasks
tasks = [litellm_completion() for _ in range(n)] tasks = [litellm_completion() for _ in range(n)]
chat_completions = await asyncio.gather(*tasks) chat_completions = await asyncio.gather(*tasks)

View file

@ -7,6 +7,10 @@ from litellm.proxy._types import (
LiteLLM_VerificationToken, LiteLLM_VerificationToken,
LiteLLM_VerificationTokenView, LiteLLM_VerificationTokenView,
LiteLLM_SpendLogs, LiteLLM_SpendLogs,
LiteLLM_UserTable,
LiteLLM_EndUserTable,
LiteLLM_TeamTable,
Member,
) )
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import ( from litellm.proxy.hooks.parallel_request_limiter import (
@ -472,6 +476,12 @@ def on_backoff(details):
class PrismaClient: class PrismaClient:
user_list_transactons: dict = {}
end_user_list_transactons: dict = {}
key_list_transactons: dict = {}
team_list_transactons: dict = {}
spend_log_transactons: List = []
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print_verbose( print_verbose(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
@ -1841,6 +1851,141 @@ async def reset_budget(prisma_client: PrismaClient):
) )
async def update_spend(
prisma_client: PrismaClient,
):
"""
Batch write updates to db.
Triggered every minute.
Requires:
user_id_list: dict,
keys_list: list,
team_list: list,
spend_logs: list,
"""
verbose_proxy_logger.debug(
f"ENTERS UPDATE SPEND - len(prisma_client.user_list_transactons.keys()): {len(prisma_client.user_list_transactons.keys())}"
)
n_retry_times = 3
### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
try:
async with prisma_client.db.tx(timeout=6000) as transaction:
async with transaction.batch_() as batcher:
for (
user_id,
response_cost,
) in prisma_client.user_list_transactons.items():
batcher.litellm_usertable.update_many(
where={"user_id": user_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.user_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
raise e
### UPDATE END-USER TABLE ###
if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
try:
async with prisma_client.db.tx(timeout=6000) as transaction:
async with transaction.batch_() as batcher:
for (
end_user_id,
response_cost,
) in prisma_client.end_user_list_transactons.items():
max_user_budget = None
if litellm.max_user_budget is not None:
max_user_budget = litellm.max_user_budget
new_user_obj = LiteLLM_EndUserTable(
user_id=end_user_id, spend=response_cost, blocked=False
)
batcher.litellm_endusertable.update_many(
where={"user_id": end_user_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.end_user_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
raise e
### UPDATE KEY TABLE ###
if len(prisma_client.key_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
try:
async with prisma_client.db.tx(timeout=6000) as transaction:
async with transaction.batch_() as batcher:
for (
token,
response_cost,
) in prisma_client.key_list_transactons.items():
batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists
where={"token": token},
data={"spend": {"increment": response_cost}},
)
prisma_client.key_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
raise e
### UPDATE TEAM TABLE ###
if len(prisma_client.team_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
try:
async with prisma_client.db.tx(timeout=6000) as transaction:
async with transaction.batch_() as batcher:
for (
team_id,
response_cost,
) in prisma_client.team_list_transactons.items():
batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists
where={"team_id": team_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.team_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
raise e
### UPDATE SPEND LOGS TABLE ###
async def monitor_spend_list(prisma_client: PrismaClient):
"""
Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db
"""
if len(prisma_client.user_list_transactons) > 10000:
await update_spend(prisma_client=prisma_client)
async def _read_request_body(request): async def _read_request_body(request):
""" """
Asynchronous function to read the request body and parse it as JSON or literal data. Asynchronous function to read the request body and parse it as JSON or literal data.

View file

@ -374,7 +374,8 @@ def test_gemini_pro_vision_base64():
print(resp) print(resp)
prompt_tokens = resp.usage.prompt_tokens prompt_tokens = resp.usage.prompt_tokens
except litellm.RateLimitError as e:
pass
except Exception as e: except Exception as e:
if "500 Internal error encountered.'" in str(e): if "500 Internal error encountered.'" in str(e):
pass pass
@ -457,6 +458,7 @@ def test_gemini_pro_function_calling_streaming():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gemini_pro_async_function_calling(): async def test_gemini_pro_async_function_calling():
load_vertex_ai_credentials() load_vertex_ai_credentials()
try:
tools = [ tools = [
{ {
"type": "function", "type": "function",
@ -470,20 +472,29 @@ async def test_gemini_pro_async_function_calling():
"type": "string", "type": "string",
"description": "The city and state, e.g. San Francisco, CA", "description": "The city and state, e.g. San Francisco, CA",
}, },
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, "unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
}, },
"required": ["location"], "required": ["location"],
}, },
}, },
} }
] ]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] messages = [
{"role": "user", "content": "What's the weather like in Boston today?"}
]
completion = await litellm.acompletion( completion = await litellm.acompletion(
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto" model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
) )
print(f"completion: {completion}") print(f"completion: {completion}")
assert completion.choices[0].message.content is None assert completion.choices[0].message.content is None
assert len(completion.choices[0].message.tool_calls) == 1 assert len(completion.choices[0].message.tool_calls) == 1
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# raise Exception("it worked!") # raise Exception("it worked!")
@ -499,6 +510,8 @@ def test_vertexai_embedding():
input=["good morning from litellm", "this is another item"], input=["good morning from litellm", "this is another item"],
) )
print(f"response:", response) print(f"response:", response)
except litellm.RateLimitError as e:
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@ -513,6 +526,8 @@ async def test_vertexai_aembedding():
input=["good morning from litellm", "this is another item"], input=["good morning from litellm", "this is another item"],
) )
print(f"response: {response}") print(f"response: {response}")
except litellm.RateLimitError as e:
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")

View file

@ -0,0 +1,95 @@
# What is this?
## This tests the batch update spend logic on the proxy server
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
from fastapi import Request
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
import pytest, logging, asyncio
import litellm, asyncio
from litellm.proxy.proxy_server import (
new_user,
generate_key_fn,
user_api_key_auth,
user_update,
delete_key_fn,
info_key_fn,
update_key_fn,
generate_key_fn,
generate_key_helper_fn,
spend_user_fn,
spend_key_fn,
view_spend_logs,
user_info,
block_user,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
from litellm._logging import verbose_proxy_logger
verbose_proxy_logger.setLevel(level=logging.DEBUG)
from litellm.proxy._types import (
NewUserRequest,
GenerateKeyRequest,
DynamoDBArgs,
KeyRequest,
UpdateKeyRequest,
GenerateKeyRequest,
BlockUsers,
)
from litellm.proxy.utils import DBClient
from starlette.datastructures import URL
from litellm.caching import DualCache
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
@pytest.fixture
def prisma_client():
from litellm.proxy.proxy_cli import append_query_params
### add connection pool + pool timeout args
params = {"connection_limit": 100, "pool_timeout": 60}
database_url = os.getenv("DATABASE_URL")
modified_url = append_query_params(database_url, params)
os.environ["DATABASE_URL"] = modified_url
# Assuming DBClient is a class that needs to be instantiated
prisma_client = PrismaClient(
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
)
# Reset litellm.proxy.proxy_server.prisma_client to None
litellm.proxy.proxy_server.custom_db_client = None
litellm.proxy.proxy_server.litellm_proxy_budget_name = (
f"litellm-proxy-budget-{time.time()}"
)
litellm.proxy.proxy_server.user_custom_key_generate = None
return prisma_client
@pytest.mark.asyncio
async def test_batch_update_spend(prisma_client):
prisma_client.user_list_transactons["test-litellm-user-5"] = 23
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
await update_spend(prisma_client=litellm.proxy.proxy_server.prisma_client)

View file

@ -50,6 +50,7 @@ general_settings:
master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234) master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
proxy_budget_rescheduler_min_time: 60 proxy_budget_rescheduler_min_time: 60
proxy_budget_rescheduler_max_time: 64 proxy_budget_rescheduler_max_time: 64
proxy_batch_write_at: 1
# database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy # database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
# environment_variables: # environment_variables:

View file

@ -329,6 +329,16 @@ async def test_key_info_spend_values():
- make completion call - make completion call
- assert cost is expected value - assert cost is expected value
""" """
async def retry_request(func, *args, _max_attempts=5, **kwargs):
for attempt in range(_max_attempts):
try:
return await func(*args, **kwargs)
except aiohttp.client_exceptions.ClientOSError as e:
if attempt + 1 == _max_attempts:
raise # re-raise the last ClientOSError if all attempts failed
print(f"Attempt {attempt+1} failed, retrying...")
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
## Test Spend Update ## ## Test Spend Update ##
# completion # completion
@ -336,7 +346,9 @@ async def test_key_info_spend_values():
key = key_gen["key"] key = key_gen["key"]
response = await chat_completion(session=session, key=key) response = await chat_completion(session=session, key=key)
await asyncio.sleep(5) await asyncio.sleep(5)
spend_logs = await get_spend_logs(session=session, request_id=response["id"]) spend_logs = await retry_request(
get_spend_logs, session=session, request_id=response["id"]
)
print(f"spend_logs: {spend_logs}") print(f"spend_logs: {spend_logs}")
completion_tokens = spend_logs[0]["completion_tokens"] completion_tokens = spend_logs[0]["completion_tokens"]
prompt_tokens = spend_logs[0]["prompt_tokens"] prompt_tokens = spend_logs[0]["prompt_tokens"]