From 077b9c6234f264063ba03f632c03a70422c0269a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 16 Mar 2024 22:32:00 -0700 Subject: [PATCH 1/7] fix(proxy/utils.py): move to batch writing db updates --- litellm/proxy/_new_secret_config.yaml | 3 +- litellm/proxy/proxy_server.py | 200 ++++++++++++++------------ litellm/proxy/utils.py | 64 +++++++++ litellm/tests/test_update_spend.py | 97 +++++++++++++ 4 files changed, 269 insertions(+), 95 deletions(-) create mode 100644 litellm/tests/test_update_spend.py diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 1c41d79fc..bd277bbdf 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -14,7 +14,8 @@ litellm_settings: cache_params: type: redis callbacks: ["batch_redis_requests"] + # success_callbacks: ["langfuse"] general_settings: master_key: sk-1234 - # database_url: "postgresql://krrishdholakia:9yQkKWiB8vVs@ep-icy-union-a5j4dwls.us-east-2.aws.neon.tech/neondb?sslmode=require" \ No newline at end of file + database_url: "postgresql://neondb_owner:hz8tyUlJ5ivV@ep-cool-sunset-a5ywubeh.us-east-2.aws.neon.tech/neondb?sslmode=require" \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a4da4e4a8..78f37ad34 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -96,6 +96,8 @@ from litellm.proxy.utils import ( _is_user_proxy_admin, _is_projected_spend_over_limit, _get_projected_spend_over_limit, + update_spend, + monitor_spend_list, ) from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager @@ -1102,108 +1104,112 @@ async def update_database( try: if prisma_client is not None: # update user_ids = [user_id, litellm_proxy_budget_name] + ### KEY CHANGE ### + for _id in user_ids: + prisma_client.user_list_transactons.append((_id, response_cost)) + ###### ## do a group update for the user-id of the key + global proxy budget - await prisma_client.db.litellm_usertable.update_many( - where={"user_id": {"in": user_ids}}, - data={"spend": {"increment": response_cost}}, - ) - if end_user_id is not None: - if existing_user_obj is None: - # if user does not exist in LiteLLM_UserTable, create a new user - existing_spend = 0 - max_user_budget = None - if litellm.max_user_budget is not None: - max_user_budget = litellm.max_user_budget - existing_user_obj = LiteLLM_UserTable( - user_id=end_user_id, - spend=0, - max_budget=max_user_budget, - user_email=None, - ) + # await prisma_client.db.litellm_usertable.update_many( + # where={"user_id": {"in": user_ids}}, + # data={"spend": {"increment": response_cost}}, + # ) + # if end_user_id is not None: + # if existing_user_obj is None: + # # if user does not exist in LiteLLM_UserTable, create a new user + # existing_spend = 0 + # max_user_budget = None + # if litellm.max_user_budget is not None: + # max_user_budget = litellm.max_user_budget + # existing_user_obj = LiteLLM_UserTable( + # user_id=end_user_id, + # spend=0, + # max_budget=max_user_budget, + # user_email=None, + # ) - else: - existing_user_obj.spend = ( - existing_user_obj.spend + response_cost - ) + # else: + # existing_user_obj.spend = ( + # existing_user_obj.spend + response_cost + # ) - user_object_json = {**existing_user_obj.json(exclude_none=True)} + # user_object_json = {**existing_user_obj.json(exclude_none=True)} - user_object_json["model_max_budget"] = json.dumps( - user_object_json["model_max_budget"] - ) - user_object_json["model_spend"] = json.dumps( - user_object_json["model_spend"] - ) + # user_object_json["model_max_budget"] = json.dumps( + # user_object_json["model_max_budget"] + # ) + # user_object_json["model_spend"] = json.dumps( + # user_object_json["model_spend"] + # ) - await prisma_client.db.litellm_usertable.upsert( - where={"user_id": end_user_id}, - data={ - "create": user_object_json, - "update": {"spend": {"increment": response_cost}}, - }, - ) + # await prisma_client.db.litellm_usertable.upsert( + # where={"user_id": end_user_id}, + # data={ + # "create": user_object_json, + # "update": {"spend": {"increment": response_cost}}, + # }, + # ) - elif custom_db_client is not None: - for id in user_ids: - if id is None: - continue - if ( - custom_db_client is not None - and id != litellm_proxy_budget_name - ): - existing_spend_obj = await custom_db_client.get_data( - key=id, table_name="user" - ) - verbose_proxy_logger.debug( - f"Updating existing_spend_obj: {existing_spend_obj}" - ) - if existing_spend_obj is None: - # if user does not exist in LiteLLM_UserTable, create a new user - existing_spend = 0 - max_user_budget = None - if litellm.max_user_budget is not None: - max_user_budget = litellm.max_user_budget - existing_spend_obj = LiteLLM_UserTable( - user_id=id, - spend=0, - max_budget=max_user_budget, - user_email=None, - ) - else: - existing_spend = existing_spend_obj.spend + # elif custom_db_client is not None: + # for id in user_ids: + # if id is None: + # continue + # if ( + # custom_db_client is not None + # and id != litellm_proxy_budget_name + # ): + # existing_spend_obj = await custom_db_client.get_data( + # key=id, table_name="user" + # ) + # verbose_proxy_logger.debug( + # f"Updating existing_spend_obj: {existing_spend_obj}" + # ) + # if existing_spend_obj is None: + # # if user does not exist in LiteLLM_UserTable, create a new user + # existing_spend = 0 + # max_user_budget = None + # if litellm.max_user_budget is not None: + # max_user_budget = litellm.max_user_budget + # existing_spend_obj = LiteLLM_UserTable( + # user_id=id, + # spend=0, + # max_budget=max_user_budget, + # user_email=None, + # ) + # else: + # existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - existing_spend_obj.spend = existing_spend + response_cost + # # Calculate the new cost by adding the existing cost and response_cost + # existing_spend_obj.spend = existing_spend + response_cost - # track cost per model, for the given user - spend_per_model = existing_spend_obj.model_spend or {} - current_model = kwargs.get("model") + # # track cost per model, for the given user + # spend_per_model = existing_spend_obj.model_spend or {} + # current_model = kwargs.get("model") - if current_model is not None and spend_per_model is not None: - if spend_per_model.get(current_model) is None: - spend_per_model[current_model] = response_cost - else: - spend_per_model[current_model] += response_cost - existing_spend_obj.model_spend = spend_per_model + # if current_model is not None and spend_per_model is not None: + # if spend_per_model.get(current_model) is None: + # spend_per_model[current_model] = response_cost + # else: + # spend_per_model[current_model] += response_cost + # existing_spend_obj.model_spend = spend_per_model - valid_token = user_api_key_cache.get_cache(key=id) - if valid_token is not None and isinstance(valid_token, dict): - user_api_key_cache.set_cache( - key=id, value=existing_spend_obj.json() - ) + # valid_token = user_api_key_cache.get_cache(key=id) + # if valid_token is not None and isinstance(valid_token, dict): + # user_api_key_cache.set_cache( + # key=id, value=existing_spend_obj.json() + # ) - verbose_proxy_logger.debug( - f"user - new cost: {existing_spend_obj.spend}, user_id: {id}" - ) - data_list.append(existing_spend_obj) + # verbose_proxy_logger.debug( + # f"user - new cost: {existing_spend_obj.spend}, user_id: {id}" + # ) + # data_list.append(existing_spend_obj) - if custom_db_client is not None and user_id is not None: - new_spend = data_list[0].spend - await custom_db_client.update_data( - key=user_id, - value={"spend": new_spend}, - table_name="user", - ) + # if custom_db_client is not None and user_id is not None: + # new_spend = data_list[0].spend + # await custom_db_client.update_data( + # key=user_id, + # value={"spend": new_spend}, + # table_name="user", + # ) except Exception as e: verbose_proxy_logger.info( "\033[91m" @@ -1323,10 +1329,12 @@ async def update_database( ) raise e - asyncio.create_task(_update_user_db()) - asyncio.create_task(_update_key_db()) - asyncio.create_task(_update_team_db()) - asyncio.create_task(_insert_spend_log_to_db()) + # asyncio.create_task(_update_user_db()) + await _update_user_db() + + # asyncio.create_task(_update_key_db()) + # asyncio.create_task(_update_team_db()) + # asyncio.create_task(_insert_spend_log_to_db()) verbose_proxy_logger.debug("Runs spend update on all tables") except Exception as e: @@ -2632,6 +2640,10 @@ async def startup_event(): scheduler.add_job( reset_budget, "interval", seconds=interval, args=[prisma_client] ) + # scheduler.add_job( + # monitor_spend_list, "interval", seconds=10, args=[prisma_client] + # ) + scheduler.add_job(update_spend, "interval", seconds=60, args=[prisma_client]) scheduler.start() diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 57381bac1..17f4f9842 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -472,6 +472,11 @@ def on_backoff(details): class PrismaClient: + user_list_transactons: List = [] + key_list_transactons: List = [] + team_list_transactons: List = [] + spend_log_transactons: List = [] + def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): print_verbose( "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" @@ -1841,6 +1846,65 @@ async def reset_budget(prisma_client: PrismaClient): ) +async def update_spend( + prisma_client: PrismaClient, +): + """ + Batch write updates to db. + + Triggered every minute. + + Requires: + user_id_list: list, + keys_list: list, + team_list: list, + spend_logs: list, + """ + n_retry_times = 3 + ### UPDATE USER TABLE ### + if len(prisma_client.user_list_transactons) > 0: + for i in range(n_retry_times + 1): + try: + remaining_transactions = list(prisma_client.user_list_transactons) + while remaining_transactions: + batch_size = min(5000, len(remaining_transactions)) + batch_transactions = remaining_transactions[:batch_size] + async with prisma_client.db.tx(timeout=60000) as transaction: + async with transaction.batch_() as batcher: + for user_id_tuple in batch_transactions: + user_id, response_cost = user_id_tuple + if user_id != "litellm-proxy-budget": + batcher.litellm_usertable.update( + where={"user_id": user_id}, + data={"spend": {"increment": response_cost}}, + ) + + remaining_transactions = remaining_transactions[batch_size:] + + prisma_client.user_list_transactons = ( + [] + ) # Clear the remaining transactions after processing all batches in the loop. + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + + ### UPDATE KEY TABLE ### + ### UPDATE TEAM TABLE ### + ### UPDATE SPEND LOGS TABLE ### + + +async def monitor_spend_list(prisma_client: PrismaClient): + """ + Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db + """ + if len(prisma_client.user_list_transactons) > 10000: + await update_spend(prisma_client=prisma_client) + + async def _read_request_body(request): """ Asynchronous function to read the request body and parse it as JSON or literal data. diff --git a/litellm/tests/test_update_spend.py b/litellm/tests/test_update_spend.py new file mode 100644 index 000000000..1c9dc7d4a --- /dev/null +++ b/litellm/tests/test_update_spend.py @@ -0,0 +1,97 @@ +# What is this? +## This tests the batch update spend logic on the proxy server + + +import sys, os, asyncio, time, random +from datetime import datetime +import traceback +from dotenv import load_dotenv +from fastapi import Request + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +import litellm +from litellm import Router, mock_completion +from litellm.proxy.utils import ProxyLogging +from litellm.proxy._types import UserAPIKeyAuth +from litellm.caching import DualCache +from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token + +import pytest, logging, asyncio +import litellm, asyncio +from litellm.proxy.proxy_server import ( + new_user, + generate_key_fn, + user_api_key_auth, + user_update, + delete_key_fn, + info_key_fn, + update_key_fn, + generate_key_fn, + generate_key_helper_fn, + spend_user_fn, + spend_key_fn, + view_spend_logs, + user_info, + block_user, +) +from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend +from litellm._logging import verbose_proxy_logger + +verbose_proxy_logger.setLevel(level=logging.DEBUG) + +from litellm.proxy._types import ( + NewUserRequest, + GenerateKeyRequest, + DynamoDBArgs, + KeyRequest, + UpdateKeyRequest, + GenerateKeyRequest, + BlockUsers, +) +from litellm.proxy.utils import DBClient +from starlette.datastructures import URL +from litellm.caching import DualCache + +proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + + +@pytest.fixture +def prisma_client(): + from litellm.proxy.proxy_cli import append_query_params + + ### add connection pool + pool timeout args + params = {"connection_limit": 100, "pool_timeout": 60} + database_url = os.getenv("DATABASE_URL") + modified_url = append_query_params(database_url, params) + os.environ["DATABASE_URL"] = modified_url + + # Assuming DBClient is a class that needs to be instantiated + prisma_client = PrismaClient( + database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj + ) + + # Reset litellm.proxy.proxy_server.prisma_client to None + litellm.proxy.proxy_server.custom_db_client = None + litellm.proxy.proxy_server.litellm_proxy_budget_name = ( + f"litellm-proxy-budget-{time.time()}" + ) + litellm.proxy.proxy_server.user_custom_key_generate = None + + return prisma_client + + +@pytest.mark.asyncio +async def test_batch_update_spend(prisma_client): + prisma_client.user_list_transactons.append(("test-litellm-user-5", 23)) + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + await update_spend(prisma_client=litellm.proxy.proxy_server.prisma_client) + + raise Exception("it worked!") From 8fefe625d950788d4dbf5bd65955bd641a2f95fa Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 18 Mar 2024 16:47:02 -0700 Subject: [PATCH 2/7] fix(proxy/utils.py): batch writing updates to db --- litellm/proxy/proxy_server.py | 227 +++----------------- litellm/proxy/tests/load_test_completion.py | 6 +- litellm/proxy/utils.py | 125 +++++++++-- 3 files changed, 141 insertions(+), 217 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 78f37ad34..588846bd3 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -996,10 +996,8 @@ async def _PROXY_track_cost_callback( ) litellm_params = kwargs.get("litellm_params", {}) or {} proxy_server_request = litellm_params.get("proxy_server_request") or {} - user_id = proxy_server_request.get("body", {}).get("user", None) - user_id = user_id or kwargs["litellm_params"]["metadata"].get( - "user_api_key_user_id", None - ) + end_user_id = proxy_server_request.get("body", {}).get("user", None) + user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None) team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None) if kwargs.get("response_cost", None) is not None: response_cost = kwargs["response_cost"] @@ -1013,9 +1011,6 @@ async def _PROXY_track_cost_callback( f"Cache Hit: response_cost {response_cost}, for user_id {user_id}" ) - verbose_proxy_logger.info( - f"response_cost {response_cost}, for user_id {user_id}" - ) verbose_proxy_logger.debug( f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}" ) @@ -1025,6 +1020,7 @@ async def _PROXY_track_cost_callback( token=user_api_key, response_cost=response_cost, user_id=user_id, + end_user_id=end_user_id, team_id=team_id, kwargs=kwargs, completion_response=completion_response, @@ -1066,6 +1062,7 @@ async def update_database( token, response_cost, user_id=None, + end_user_id=None, team_id=None, kwargs=None, completion_response=None, @@ -1076,6 +1073,10 @@ async def update_database( verbose_proxy_logger.info( f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}" ) + if isinstance(token, str) and token.startswith("sk-"): + hashed_token = hash_token(token=token) + else: + hashed_token = token ### UPDATE USER SPEND ### async def _update_user_db(): @@ -1084,11 +1085,6 @@ async def update_database( - Update litellm-proxy-budget row (global proxy spend) """ ## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db - end_user_id = None - if isinstance(token, str) and token.startswith("sk-"): - hashed_token = hash_token(token=token) - else: - hashed_token = token existing_token_obj = await user_api_key_cache.async_get_cache( key=hashed_token ) @@ -1097,119 +1093,24 @@ async def update_database( existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if existing_user_obj is not None and isinstance(existing_user_obj, dict): existing_user_obj = LiteLLM_UserTable(**existing_user_obj) - if existing_token_obj.user_id != user_id: # an end-user id was passed in - end_user_id = user_id - user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name] - data_list = [] try: if prisma_client is not None: # update - user_ids = [user_id, litellm_proxy_budget_name] + user_ids = [user_id] + if ( + litellm.max_budget > 0 + ): # track global proxy budget, if user set max budget + user_ids.append(litellm_proxy_budget_name) ### KEY CHANGE ### for _id in user_ids: - prisma_client.user_list_transactons.append((_id, response_cost)) - ###### - ## do a group update for the user-id of the key + global proxy budget - # await prisma_client.db.litellm_usertable.update_many( - # where={"user_id": {"in": user_ids}}, - # data={"spend": {"increment": response_cost}}, - # ) - # if end_user_id is not None: - # if existing_user_obj is None: - # # if user does not exist in LiteLLM_UserTable, create a new user - # existing_spend = 0 - # max_user_budget = None - # if litellm.max_user_budget is not None: - # max_user_budget = litellm.max_user_budget - # existing_user_obj = LiteLLM_UserTable( - # user_id=end_user_id, - # spend=0, - # max_budget=max_user_budget, - # user_email=None, - # ) - - # else: - # existing_user_obj.spend = ( - # existing_user_obj.spend + response_cost - # ) - - # user_object_json = {**existing_user_obj.json(exclude_none=True)} - - # user_object_json["model_max_budget"] = json.dumps( - # user_object_json["model_max_budget"] - # ) - # user_object_json["model_spend"] = json.dumps( - # user_object_json["model_spend"] - # ) - - # await prisma_client.db.litellm_usertable.upsert( - # where={"user_id": end_user_id}, - # data={ - # "create": user_object_json, - # "update": {"spend": {"increment": response_cost}}, - # }, - # ) - - # elif custom_db_client is not None: - # for id in user_ids: - # if id is None: - # continue - # if ( - # custom_db_client is not None - # and id != litellm_proxy_budget_name - # ): - # existing_spend_obj = await custom_db_client.get_data( - # key=id, table_name="user" - # ) - # verbose_proxy_logger.debug( - # f"Updating existing_spend_obj: {existing_spend_obj}" - # ) - # if existing_spend_obj is None: - # # if user does not exist in LiteLLM_UserTable, create a new user - # existing_spend = 0 - # max_user_budget = None - # if litellm.max_user_budget is not None: - # max_user_budget = litellm.max_user_budget - # existing_spend_obj = LiteLLM_UserTable( - # user_id=id, - # spend=0, - # max_budget=max_user_budget, - # user_email=None, - # ) - # else: - # existing_spend = existing_spend_obj.spend - - # # Calculate the new cost by adding the existing cost and response_cost - # existing_spend_obj.spend = existing_spend + response_cost - - # # track cost per model, for the given user - # spend_per_model = existing_spend_obj.model_spend or {} - # current_model = kwargs.get("model") - - # if current_model is not None and spend_per_model is not None: - # if spend_per_model.get(current_model) is None: - # spend_per_model[current_model] = response_cost - # else: - # spend_per_model[current_model] += response_cost - # existing_spend_obj.model_spend = spend_per_model - - # valid_token = user_api_key_cache.get_cache(key=id) - # if valid_token is not None and isinstance(valid_token, dict): - # user_api_key_cache.set_cache( - # key=id, value=existing_spend_obj.json() - # ) - - # verbose_proxy_logger.debug( - # f"user - new cost: {existing_spend_obj.spend}, user_id: {id}" - # ) - # data_list.append(existing_spend_obj) - - # if custom_db_client is not None and user_id is not None: - # new_spend = data_list[0].spend - # await custom_db_client.update_data( - # key=user_id, - # value={"spend": new_spend}, - # table_name="user", - # ) + prisma_client.user_list_transactons[_id] = ( + response_cost + + prisma_client.user_list_transactons.get(_id, 0) + ) + if end_user_id is not None: + prisma_client.end_user_list_transactons[end_user_id] = ( + response_cost + + prisma_client.user_list_transactons.get(end_user_id, 0) + ) except Exception as e: verbose_proxy_logger.info( "\033[91m" @@ -1220,38 +1121,13 @@ async def update_database( async def _update_key_db(): try: verbose_proxy_logger.debug( - f"adding spend to key db. Response cost: {response_cost}. Token: {token}." + f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}." ) if prisma_client is not None: - await prisma_client.db.litellm_verificationtoken.update( - where={"token": token}, - data={"spend": {"increment": response_cost}}, + prisma_client.key_list_transactons[hashed_token] = ( + response_cost + + prisma_client.key_list_transactons.get(hashed_token, 0) ) - elif custom_db_client is not None: - # Fetch the existing cost for the given token - existing_spend_obj = await custom_db_client.get_data( - key=token, table_name="key" - ) - verbose_proxy_logger.debug( - f"_update_key_db existing spend: {existing_spend_obj}" - ) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost - - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await custom_db_client.update_data( - key=token, value={"spend": new_spend}, table_name="key" - ) - - valid_token = user_api_key_cache.get_cache(key=token) - if valid_token is not None: - valid_token.spend = new_spend - user_api_key_cache.set_cache(key=token, value=valid_token) except Exception as e: verbose_proxy_logger.info( f"Update Key DB Call failed to execute - {str(e)}\n{traceback.format_exc()}" @@ -1273,16 +1149,13 @@ async def update_database( payload["spend"] = response_cost if prisma_client is not None: await prisma_client.insert_data(data=payload, table_name="spend") - elif custom_db_client is not None: - await custom_db_client.insert_data(payload, table_name="spend") - except Exception as e: verbose_proxy_logger.info( f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}" ) raise e - ### UPDATE KEY SPEND ### + ### UPDATE TEAM SPEND ### async def _update_team_db(): try: verbose_proxy_logger.debug( @@ -1294,46 +1167,19 @@ async def update_database( ) return if prisma_client is not None: - await prisma_client.db.litellm_teamtable.update( - where={"team_id": team_id}, - data={"spend": {"increment": response_cost}}, + prisma_client.team_list_transactons[team_id] = ( + response_cost + + prisma_client.team_list_transactons.get(team_id, 0) ) - elif custom_db_client is not None: - # Fetch the existing cost for the given token - existing_spend_obj = await custom_db_client.get_data( - key=token, table_name="key" - ) - verbose_proxy_logger.debug( - f"_update_key_db existing spend: {existing_spend_obj}" - ) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost - - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await custom_db_client.update_data( - key=token, value={"spend": new_spend}, table_name="key" - ) - - valid_token = user_api_key_cache.get_cache(key=token) - if valid_token is not None: - valid_token.spend = new_spend - user_api_key_cache.set_cache(key=token, value=valid_token) except Exception as e: verbose_proxy_logger.info( f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}" ) raise e - # asyncio.create_task(_update_user_db()) - await _update_user_db() - - # asyncio.create_task(_update_key_db()) - # asyncio.create_task(_update_team_db()) + asyncio.create_task(_update_user_db()) + asyncio.create_task(_update_key_db()) + asyncio.create_task(_update_team_db()) # asyncio.create_task(_insert_spend_log_to_db()) verbose_proxy_logger.debug("Runs spend update on all tables") @@ -2237,7 +2083,6 @@ async def generate_key_helper_fn( saved_token["expires"] = saved_token["expires"].isoformat() if prisma_client is not None: ## CREATE USER (If necessary) - verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}") if query_type == "insert_data": user_row = await prisma_client.insert_data( data=user_data, table_name="user" @@ -2575,7 +2420,6 @@ async def startup_event(): # add master key to db if os.getenv("PROXY_ADMIN_ID", None) is not None: litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID") - asyncio.create_task( generate_key_helper_fn( duration=None, @@ -2640,10 +2484,7 @@ async def startup_event(): scheduler.add_job( reset_budget, "interval", seconds=interval, args=[prisma_client] ) - # scheduler.add_job( - # monitor_spend_list, "interval", seconds=10, args=[prisma_client] - # ) - scheduler.add_job(update_spend, "interval", seconds=60, args=[prisma_client]) + scheduler.add_job(update_spend, "interval", seconds=10, args=[prisma_client]) scheduler.start() diff --git a/litellm/proxy/tests/load_test_completion.py b/litellm/proxy/tests/load_test_completion.py index 3f0da2e94..9450c1cb5 100644 --- a/litellm/proxy/tests/load_test_completion.py +++ b/litellm/proxy/tests/load_test_completion.py @@ -7,6 +7,7 @@ from dotenv import load_dotenv litellm_client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234") + async def litellm_completion(): # Your existing code for litellm_completion goes here try: @@ -18,6 +19,7 @@ async def litellm_completion(): "content": f"{text}. Who was alexander the great? {uuid.uuid4()}", } ], + user="my-new-end-user-1", ) return response @@ -29,9 +31,9 @@ async def litellm_completion(): async def main(): - for i in range(6): + for i in range(3): start = time.time() - n = 20 # Number of concurrent tasks + n = 10 # Number of concurrent tasks tasks = [litellm_completion() for _ in range(n)] chat_completions = await asyncio.gather(*tasks) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 17f4f9842..cc41c8ec8 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -7,6 +7,10 @@ from litellm.proxy._types import ( LiteLLM_VerificationToken, LiteLLM_VerificationTokenView, LiteLLM_SpendLogs, + LiteLLM_UserTable, + LiteLLM_EndUserTable, + LiteLLM_TeamTable, + Member, ) from litellm.caching import DualCache from litellm.proxy.hooks.parallel_request_limiter import ( @@ -472,9 +476,10 @@ def on_backoff(details): class PrismaClient: - user_list_transactons: List = [] - key_list_transactons: List = [] - team_list_transactons: List = [] + user_list_transactons: dict = {} + end_user_list_transactons: dict = {} + key_list_transactons: dict = {} + team_list_transactons: dict = {} spend_log_transactons: List = [] def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): @@ -1855,34 +1860,62 @@ async def update_spend( Triggered every minute. Requires: - user_id_list: list, + user_id_list: dict, keys_list: list, team_list: list, spend_logs: list, """ + verbose_proxy_logger.debug( + f"ENTERS UPDATE SPEND - len(prisma_client.user_list_transactons.keys()): {len(prisma_client.user_list_transactons.keys())}" + ) n_retry_times = 3 ### UPDATE USER TABLE ### - if len(prisma_client.user_list_transactons) > 0: + if len(prisma_client.user_list_transactons.keys()) > 0: for i in range(n_retry_times + 1): try: - remaining_transactions = list(prisma_client.user_list_transactons) - while remaining_transactions: - batch_size = min(5000, len(remaining_transactions)) - batch_transactions = remaining_transactions[:batch_size] - async with prisma_client.db.tx(timeout=60000) as transaction: - async with transaction.batch_() as batcher: - for user_id_tuple in batch_transactions: - user_id, response_cost = user_id_tuple - if user_id != "litellm-proxy-budget": - batcher.litellm_usertable.update( - where={"user_id": user_id}, - data={"spend": {"increment": response_cost}}, - ) - - remaining_transactions = remaining_transactions[batch_size:] - + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + user_id, + response_cost, + ) in prisma_client.user_list_transactons.items(): + batcher.litellm_usertable.update_many( + where={"user_id": user_id}, + data={"spend": {"increment": response_cost}}, + ) prisma_client.user_list_transactons = ( - [] + {} + ) # Clear the remaining transactions after processing all batches in the loop. + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + + ### UPDATE END-USER TABLE ### + if len(prisma_client.end_user_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + end_user_id, + response_cost, + ) in prisma_client.end_user_list_transactons.items(): + max_user_budget = None + if litellm.max_user_budget is not None: + max_user_budget = litellm.max_user_budget + new_user_obj = LiteLLM_EndUserTable( + user_id=end_user_id, spend=response_cost, blocked=False + ) + batcher.litellm_endusertable.update_many( + where={"user_id": end_user_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.end_user_list_transactons = ( + {} ) # Clear the remaining transactions after processing all batches in the loop. except httpx.ReadTimeout: if i >= n_retry_times: # If we've reached the maximum number of retries @@ -1893,7 +1926,55 @@ async def update_spend( raise e ### UPDATE KEY TABLE ### + if len(prisma_client.key_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + token, + response_cost, + ) in prisma_client.key_list_transactons.items(): + batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists + where={"token": token}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.key_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + ### UPDATE TEAM TABLE ### + if len(prisma_client.team_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + team_id, + response_cost, + ) in prisma_client.team_list_transactons.items(): + batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists + where={"team_id": team_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.team_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + ### UPDATE SPEND LOGS TABLE ### From f588bff69ba74ad8c45bc490bbbbe6ac4cf0c7d4 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 18 Mar 2024 20:26:28 -0700 Subject: [PATCH 3/7] fix(proxy_server.py): fix spend log update --- litellm/proxy/proxy_server.py | 21 +++++++++++++++++---- proxy_server_config.yaml | 1 + tests/test_keys.py | 14 +++++++++++++- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 588846bd3..930ffb558 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -279,6 +279,7 @@ litellm_proxy_admin_name = "default_user_id" ui_access_mode: Literal["admin", "all"] = "all" proxy_budget_rescheduler_min_time = 597 proxy_budget_rescheduler_max_time = 605 +proxy_batch_write_at = 60 # in seconds litellm_master_key_hash = None ### INITIALIZE GLOBAL LOGGING OBJECT ### proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) @@ -1138,7 +1139,6 @@ async def update_database( async def _insert_spend_log_to_db(): try: # Helper to generate payload to log - verbose_proxy_logger.debug("inserting spend log to db") payload = get_logging_payload( kwargs=kwargs, response_obj=completion_response, @@ -1150,7 +1150,7 @@ async def update_database( if prisma_client is not None: await prisma_client.insert_data(data=payload, table_name="spend") except Exception as e: - verbose_proxy_logger.info( + verbose_proxy_logger.debug( f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}" ) raise e @@ -1181,6 +1181,7 @@ async def update_database( asyncio.create_task(_update_key_db()) asyncio.create_task(_update_team_db()) # asyncio.create_task(_insert_spend_log_to_db()) + await _insert_spend_log_to_db() verbose_proxy_logger.debug("Runs spend update on all tables") except Exception as e: @@ -1499,7 +1500,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at # Load existing config config = await self.get_config(config_file_path=config_file_path) @@ -1855,6 +1856,10 @@ class ProxyConfig: proxy_budget_rescheduler_max_time = general_settings.get( "proxy_budget_rescheduler_max_time", proxy_budget_rescheduler_max_time ) + ## BATCH WRITER ## + proxy_batch_write_at = general_settings.get( + "proxy_batch_write_at", proxy_batch_write_at + ) ### BACKGROUND HEALTH CHECKS ### # Enable background health checks use_background_health_checks = general_settings.get( @@ -2481,10 +2486,18 @@ async def startup_event(): interval = random.randint( proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time ) # random interval, so multiple workers avoid resetting budget at the same time + batch_writing_interval = random.randint( + proxy_batch_write_at - 3, proxy_batch_write_at + 3 + ) # random interval, so multiple workers avoid batch writing at the same time scheduler.add_job( reset_budget, "interval", seconds=interval, args=[prisma_client] ) - scheduler.add_job(update_spend, "interval", seconds=10, args=[prisma_client]) + scheduler.add_job( + update_spend, + "interval", + seconds=batch_writing_interval, + args=[prisma_client], + ) scheduler.start() diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index 34a61994b..d31218b8d 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -50,6 +50,7 @@ general_settings: master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234) proxy_budget_rescheduler_min_time: 60 proxy_budget_rescheduler_max_time: 64 + proxy_batch_write_at: 1 # database_url: "postgresql://:@:/" # [OPTIONAL] use for token-based auth to proxy # environment_variables: diff --git a/tests/test_keys.py b/tests/test_keys.py index cba960aca..0419e9f8a 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -329,6 +329,16 @@ async def test_key_info_spend_values(): - make completion call - assert cost is expected value """ + + async def retry_request(func, *args, _max_attempts=5, **kwargs): + for attempt in range(_max_attempts): + try: + return await func(*args, **kwargs) + except aiohttp.client_exceptions.ClientOSError as e: + if attempt + 1 == _max_attempts: + raise # re-raise the last ClientOSError if all attempts failed + print(f"Attempt {attempt+1} failed, retrying...") + async with aiohttp.ClientSession() as session: ## Test Spend Update ## # completion @@ -336,7 +346,9 @@ async def test_key_info_spend_values(): key = key_gen["key"] response = await chat_completion(session=session, key=key) await asyncio.sleep(5) - spend_logs = await get_spend_logs(session=session, request_id=response["id"]) + spend_logs = await retry_request( + get_spend_logs, session=session, request_id=response["id"] + ) print(f"spend_logs: {spend_logs}") completion_tokens = spend_logs[0]["completion_tokens"] prompt_tokens = spend_logs[0]["prompt_tokens"] From 0b39bc141b1c2fe7848fdfa1536284d051a93871 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 18 Mar 2024 20:38:21 -0700 Subject: [PATCH 4/7] test(test_update_spend.py): fix test --- litellm/tests/test_update_spend.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/litellm/tests/test_update_spend.py b/litellm/tests/test_update_spend.py index 1c9dc7d4a..0fd5d9bcf 100644 --- a/litellm/tests/test_update_spend.py +++ b/litellm/tests/test_update_spend.py @@ -88,10 +88,8 @@ def prisma_client(): @pytest.mark.asyncio async def test_batch_update_spend(prisma_client): - prisma_client.user_list_transactons.append(("test-litellm-user-5", 23)) + prisma_client.user_list_transactons["test-litellm-user-5"] = 23 setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") await litellm.proxy.proxy_server.prisma_client.connect() await update_spend(prisma_client=litellm.proxy.proxy_server.prisma_client) - - raise Exception("it worked!") From 7eaddaef10a8a64bf547430d7d3468570cf986fc Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 18 Mar 2024 21:16:28 -0700 Subject: [PATCH 5/7] refactor(proxy_server.py): re-add custom db client logic - prevent regressions --- litellm/proxy/proxy_server.py | 113 ++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 930ffb558..2e19903c4 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1094,6 +1094,7 @@ async def update_database( existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if existing_user_obj is not None and isinstance(existing_user_obj, dict): existing_user_obj = LiteLLM_UserTable(**existing_user_obj) + data_list = [] try: if prisma_client is not None: # update user_ids = [user_id] @@ -1112,6 +1113,68 @@ async def update_database( response_cost + prisma_client.user_list_transactons.get(end_user_id, 0) ) + elif custom_db_client is not None: + for id in user_ids: + if id is None: + continue + if ( + custom_db_client is not None + and id != litellm_proxy_budget_name + ): + existing_spend_obj = await custom_db_client.get_data( + key=id, table_name="user" + ) + verbose_proxy_logger.debug( + f"Updating existing_spend_obj: {existing_spend_obj}" + ) + if existing_spend_obj is None: + # if user does not exist in LiteLLM_UserTable, create a new user + existing_spend = 0 + max_user_budget = None + if litellm.max_user_budget is not None: + max_user_budget = litellm.max_user_budget + existing_spend_obj = LiteLLM_UserTable( + user_id=id, + spend=0, + max_budget=max_user_budget, + user_email=None, + ) + else: + existing_spend = existing_spend_obj.spend + + # Calculate the new cost by adding the existing cost and response_cost + existing_spend_obj.spend = existing_spend + response_cost + + # track cost per model, for the given user + spend_per_model = existing_spend_obj.model_spend or {} + current_model = kwargs.get("model") + + if current_model is not None and spend_per_model is not None: + if spend_per_model.get(current_model) is None: + spend_per_model[current_model] = response_cost + else: + spend_per_model[current_model] += response_cost + existing_spend_obj.model_spend = spend_per_model + + valid_token = user_api_key_cache.get_cache(key=id) + if valid_token is not None and isinstance(valid_token, dict): + user_api_key_cache.set_cache( + key=id, value=existing_spend_obj.json() + ) + + verbose_proxy_logger.debug( + f"user - new cost: {existing_spend_obj.spend}, user_id: {id}" + ) + data_list.append(existing_spend_obj) + + if custom_db_client is not None and user_id is not None: + new_spend = data_list[0].spend + await custom_db_client.update_data( + key=user_id, + value={"spend": new_spend}, + table_name="user", + ) + except Exception as e: verbose_proxy_logger.info( "\033[91m" @@ -1129,6 +1192,31 @@ async def update_database( response_cost + prisma_client.key_list_transactons.get(hashed_token, 0) ) + elif custom_db_client is not None: + # Fetch the existing cost for the given token + existing_spend_obj = await custom_db_client.get_data( + key=token, table_name="key" + ) + verbose_proxy_logger.debug( + f"_update_key_db existing spend: {existing_spend_obj}" + ) + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + verbose_proxy_logger.debug(f"new cost: {new_spend}") + # Update the cost column for the given token + await custom_db_client.update_data( + key=token, value={"spend": new_spend}, table_name="key" + ) + + valid_token = user_api_key_cache.get_cache(key=token) + if valid_token is not None: + valid_token.spend = new_spend + user_api_key_cache.set_cache(key=token, value=valid_token) except Exception as e: verbose_proxy_logger.info( f"Update Key DB Call failed to execute - {str(e)}\n{traceback.format_exc()}" @@ -1171,6 +1259,31 @@ async def update_database( response_cost + prisma_client.team_list_transactons.get(team_id, 0) ) + elif custom_db_client is not None: + # Fetch the existing cost for the given token + existing_spend_obj = await custom_db_client.get_data( + key=token, table_name="key" + ) + verbose_proxy_logger.debug( + f"_update_key_db existing spend: {existing_spend_obj}" + ) + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + verbose_proxy_logger.debug(f"new cost: {new_spend}") + # Update the cost column for the given token + await custom_db_client.update_data( + key=token, value={"spend": new_spend}, table_name="key" + ) + + valid_token = user_api_key_cache.get_cache(key=token) + if valid_token is not None: + valid_token.spend = new_spend + user_api_key_cache.set_cache(key=token, value=valid_token) except Exception as e: verbose_proxy_logger.info( f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}" From 2827acc48787e894a56aa4d16afd904f050e5d52 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 18 Mar 2024 21:27:32 -0700 Subject: [PATCH 6/7] refactor(main.py): trigger new build --- litellm/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/main.py b/litellm/main.py index b20858d89..3a9fed77e 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -12,6 +12,7 @@ from typing import Any, Literal, Union, BinaryIO from functools import partial import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy + import httpx import litellm from ._logging import verbose_logger From 2fdc20f549cfbf28d5da5cd3e92a96fd041d868b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 18 Mar 2024 21:33:08 -0700 Subject: [PATCH 7/7] test: handle vertex ai rate limit errors --- .../tests/test_amazing_vertex_completion.py | 65 ++++++++++++------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 67db9b61c..566345719 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -374,7 +374,8 @@ def test_gemini_pro_vision_base64(): print(resp) prompt_tokens = resp.usage.prompt_tokens - + except litellm.RateLimitError as e: + pass except Exception as e: if "500 Internal error encountered.'" in str(e): pass @@ -419,33 +420,43 @@ def test_gemini_pro_function_calling(): @pytest.mark.asyncio async def test_gemini_pro_async_function_calling(): load_vertex_ai_credentials() - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", + try: + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + "required": ["location"], }, - "required": ["location"], }, - }, - } - ] - messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] - completion = await litellm.acompletion( - model="gemini-pro", messages=messages, tools=tools, tool_choice="auto" - ) - print(f"completion: {completion}") - assert completion.choices[0].message.content is None - assert len(completion.choices[0].message.tool_calls) == 1 + } + ] + messages = [ + {"role": "user", "content": "What's the weather like in Boston today?"} + ] + completion = await litellm.acompletion( + model="gemini-pro", messages=messages, tools=tools, tool_choice="auto" + ) + print(f"completion: {completion}") + assert completion.choices[0].message.content is None + assert len(completion.choices[0].message.tool_calls) == 1 + except litellm.RateLimitError as e: + pass + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}") # raise Exception("it worked!") @@ -461,6 +472,8 @@ def test_vertexai_embedding(): input=["good morning from litellm", "this is another item"], ) print(f"response:", response) + except litellm.RateLimitError as e: + pass except Exception as e: pytest.fail(f"Error occurred: {e}") @@ -475,6 +488,8 @@ async def test_vertexai_aembedding(): input=["good morning from litellm", "this is another item"], ) print(f"response: {response}") + except litellm.RateLimitError as e: + pass except Exception as e: pytest.fail(f"Error occurred: {e}")