diff --git a/litellm/main.py b/litellm/main.py index 86d952cbf..cc11863c6 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -12,6 +12,7 @@ from typing import Any, Literal, Union, BinaryIO from functools import partial import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy + import httpx import litellm from ._logging import verbose_logger diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 1c41d79fc..bd277bbdf 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -14,7 +14,8 @@ litellm_settings: cache_params: type: redis callbacks: ["batch_redis_requests"] + # success_callbacks: ["langfuse"] general_settings: master_key: sk-1234 - # database_url: "os.environ/DATABASE_URL" \ No newline at end of file + database_url: "os.environ/DATABASE_URL" \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 529ab8b7e..98394f60e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -96,6 +96,8 @@ from litellm.proxy.utils import ( _is_user_proxy_admin, _is_projected_spend_over_limit, _get_projected_spend_over_limit, + update_spend, + monitor_spend_list, ) from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager @@ -277,6 +279,7 @@ litellm_proxy_admin_name = "default_user_id" ui_access_mode: Literal["admin", "all"] = "all" proxy_budget_rescheduler_min_time = 597 proxy_budget_rescheduler_max_time = 605 +proxy_batch_write_at = 60 # in seconds litellm_master_key_hash = None ### INITIALIZE GLOBAL LOGGING OBJECT ### proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) @@ -995,10 +998,8 @@ async def
_PROXY_track_cost_callback( ) litellm_params = kwargs.get("litellm_params", {}) or {} proxy_server_request = litellm_params.get("proxy_server_request") or {} - user_id = proxy_server_request.get("body", {}).get("user", None) - user_id = user_id or kwargs["litellm_params"]["metadata"].get( - "user_api_key_user_id", None - ) + end_user_id = proxy_server_request.get("body", {}).get("user", None) + user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None) team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None) if kwargs.get("response_cost", None) is not None: response_cost = kwargs["response_cost"] @@ -1012,9 +1013,6 @@ async def _PROXY_track_cost_callback( f"Cache Hit: response_cost {response_cost}, for user_id {user_id}" ) - verbose_proxy_logger.info( - f"response_cost {response_cost}, for user_id {user_id}" - ) verbose_proxy_logger.debug( f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}" ) @@ -1024,6 +1022,7 @@ async def _PROXY_track_cost_callback( token=user_api_key, response_cost=response_cost, user_id=user_id, + end_user_id=end_user_id, team_id=team_id, kwargs=kwargs, completion_response=completion_response, @@ -1065,6 +1064,7 @@ async def update_database( token, response_cost, user_id=None, + end_user_id=None, team_id=None, kwargs=None, completion_response=None, @@ -1075,6 +1075,10 @@ async def update_database( verbose_proxy_logger.info( f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}" ) + if isinstance(token, str) and token.startswith("sk-"): + hashed_token = hash_token(token=token) + else: + hashed_token = token ### UPDATE USER SPEND ### async def _update_user_db(): @@ -1083,11 +1087,6 @@ async def update_database( - Update litellm-proxy-budget row (global proxy spend) """ ## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db - end_user_id = None - if 
isinstance(token, str) and token.startswith("sk-"): - hashed_token = hash_token(token=token) - else: - hashed_token = token existing_token_obj = await user_api_key_cache.async_get_cache( key=hashed_token ) @@ -1096,54 +1095,25 @@ async def update_database( existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if existing_user_obj is not None and isinstance(existing_user_obj, dict): existing_user_obj = LiteLLM_UserTable(**existing_user_obj) - if existing_token_obj.user_id != user_id: # an end-user id was passed in - end_user_id = user_id - user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name] data_list = [] try: if prisma_client is not None: # update - user_ids = [user_id, litellm_proxy_budget_name] - ## do a group update for the user-id of the key + global proxy budget - await prisma_client.db.litellm_usertable.update_many( - where={"user_id": {"in": user_ids}}, - data={"spend": {"increment": response_cost}}, - ) + user_ids = [user_id] + if ( + litellm.max_budget > 0 + ): # track global proxy budget, if user set max budget + user_ids.append(litellm_proxy_budget_name) + ### KEY CHANGE ### + for _id in user_ids: + prisma_client.user_list_transactons[_id] = ( + response_cost + + prisma_client.user_list_transactons.get(_id, 0) + ) if end_user_id is not None: - if existing_user_obj is None: - # if user does not exist in LiteLLM_UserTable, create a new user - existing_spend = 0 - max_user_budget = None - if litellm.max_user_budget is not None: - max_user_budget = litellm.max_user_budget - existing_user_obj = LiteLLM_UserTable( - user_id=end_user_id, - spend=0, - max_budget=max_user_budget, - user_email=None, - ) - - else: - existing_user_obj.spend = ( - existing_user_obj.spend + response_cost - ) - - user_object_json = {**existing_user_obj.json(exclude_none=True)} - - user_object_json["model_max_budget"] = json.dumps( - user_object_json["model_max_budget"] + prisma_client.end_user_list_transactons[end_user_id] = ( + response_cost + + 
prisma_client.end_user_list_transactons.get(end_user_id, 0) ) - user_object_json["model_spend"] = json.dumps( - user_object_json["model_spend"] - ) - - await prisma_client.db.litellm_usertable.upsert( - where={"user_id": end_user_id}, - data={ - "create": user_object_json, - "update": {"spend": {"increment": response_cost}}, - }, - ) - elif custom_db_client is not None: for id in user_ids: if id is None: @@ -1205,6 +1175,7 @@ async def update_database( value={"spend": new_spend}, table_name="user", ) + except Exception as e: verbose_proxy_logger.info( "\033[91m" @@ -1215,12 +1186,12 @@ async def update_database( async def _update_key_db(): try: verbose_proxy_logger.debug( - f"adding spend to key db. Response cost: {response_cost}. Token: {token}." + f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}." ) if prisma_client is not None: - await prisma_client.db.litellm_verificationtoken.update( - where={"token": token}, - data={"spend": {"increment": response_cost}}, + prisma_client.key_list_transactons[hashed_token] = ( + response_cost + + prisma_client.key_list_transactons.get(hashed_token, 0) ) elif custom_db_client is not None: # Fetch the existing cost for the given token @@ -1257,7 +1228,6 @@ async def update_database( async def _insert_spend_log_to_db(): try: # Helper to generate payload to log - verbose_proxy_logger.debug("inserting spend log to db") payload = get_logging_payload( kwargs=kwargs, response_obj=completion_response, @@ -1268,16 +1238,13 @@ async def update_database( payload["spend"] = response_cost if prisma_client is not None: await prisma_client.insert_data(data=payload, table_name="spend") - elif custom_db_client is not None: - await custom_db_client.insert_data(payload, table_name="spend") - except Exception as e: - verbose_proxy_logger.info( + verbose_proxy_logger.debug( f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}" ) raise e - ### UPDATE KEY SPEND ### + ### UPDATE TEAM SPEND ### async
def _update_team_db(): try: verbose_proxy_logger.debug( @@ -1289,9 +1256,9 @@ async def update_database( ) return if prisma_client is not None: - await prisma_client.db.litellm_teamtable.update( - where={"team_id": team_id}, - data={"spend": {"increment": response_cost}}, + prisma_client.team_list_transactons[team_id] = ( + response_cost + + prisma_client.team_list_transactons.get(team_id, 0) ) elif custom_db_client is not None: # Fetch the existing cost for the given token @@ -1327,7 +1294,8 @@ async def update_database( asyncio.create_task(_update_user_db()) asyncio.create_task(_update_key_db()) asyncio.create_task(_update_team_db()) - asyncio.create_task(_insert_spend_log_to_db()) + # asyncio.create_task(_insert_spend_log_to_db()) + await _insert_spend_log_to_db() verbose_proxy_logger.debug("Runs spend update on all tables") except Exception as e: @@ -1646,7 +1614,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at # Load existing config config = await self.get_config(config_file_path=config_file_path) @@ -2010,6 +1978,10 @@ class ProxyConfig: proxy_budget_rescheduler_max_time = general_settings.get( "proxy_budget_rescheduler_max_time", proxy_budget_rescheduler_max_time ) + ## BATCH WRITER ## + proxy_batch_write_at = general_settings.get( + "proxy_batch_write_at", proxy_batch_write_at + ) ### 
BACKGROUND HEALTH CHECKS ### # Enable background health checks use_background_health_checks = general_settings.get( @@ -2238,7 +2210,6 @@ async def generate_key_helper_fn( saved_token["expires"] = saved_token["expires"].isoformat() if prisma_client is not None: ## CREATE USER (If necessary) - verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}") if query_type == "insert_data": user_row = await prisma_client.insert_data( data=user_data, table_name="user" @@ -2576,7 +2547,6 @@ async def startup_event(): # add master key to db if os.getenv("PROXY_ADMIN_ID", None) is not None: litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID") - asyncio.create_task( generate_key_helper_fn( duration=None, @@ -2638,9 +2608,18 @@ async def startup_event(): interval = random.randint( proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time ) # random interval, so multiple workers avoid resetting budget at the same time + batch_writing_interval = random.randint( + proxy_batch_write_at - 3, proxy_batch_write_at + 3 + ) # random interval, so multiple workers avoid batch writing at the same time scheduler.add_job( reset_budget, "interval", seconds=interval, args=[prisma_client] ) + scheduler.add_job( + update_spend, + "interval", + seconds=batch_writing_interval, + args=[prisma_client], + ) scheduler.start() diff --git a/litellm/proxy/tests/load_test_completion.py b/litellm/proxy/tests/load_test_completion.py index 3f0da2e94..9450c1cb5 100644 --- a/litellm/proxy/tests/load_test_completion.py +++ b/litellm/proxy/tests/load_test_completion.py @@ -7,6 +7,7 @@ from dotenv import load_dotenv litellm_client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234") + async def litellm_completion(): # Your existing code for litellm_completion goes here try: @@ -18,6 +19,7 @@ async def litellm_completion(): "content": f"{text}. Who was alexander the great? 
{uuid.uuid4()}", } ], + user="my-new-end-user-1", ) return response @@ -29,9 +31,9 @@ async def litellm_completion(): async def main(): - for i in range(6): + for i in range(3): start = time.time() - n = 20 # Number of concurrent tasks + n = 10 # Number of concurrent tasks tasks = [litellm_completion() for _ in range(n)] chat_completions = await asyncio.gather(*tasks) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 57381bac1..cc41c8ec8 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -7,6 +7,10 @@ from litellm.proxy._types import ( LiteLLM_VerificationToken, LiteLLM_VerificationTokenView, LiteLLM_SpendLogs, + LiteLLM_UserTable, + LiteLLM_EndUserTable, + LiteLLM_TeamTable, + Member, ) from litellm.caching import DualCache from litellm.proxy.hooks.parallel_request_limiter import ( @@ -472,6 +476,12 @@ def on_backoff(details): class PrismaClient: + user_list_transactons: dict = {} + end_user_list_transactons: dict = {} + key_list_transactons: dict = {} + team_list_transactons: dict = {} + spend_log_transactons: List = [] + def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): print_verbose( "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" @@ -1841,6 +1851,141 @@ async def reset_budget(prisma_client: PrismaClient): ) +async def update_spend( + prisma_client: PrismaClient, +): + """ + Batch write updates to db. + + Triggered every minute. 
+ + Requires: + user_id_list: dict, + keys_list: list, + team_list: list, + spend_logs: list, + """ + verbose_proxy_logger.debug( + f"ENTERS UPDATE SPEND - len(prisma_client.user_list_transactons.keys()): {len(prisma_client.user_list_transactons.keys())}" + ) + n_retry_times = 3 + ### UPDATE USER TABLE ### + if len(prisma_client.user_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + user_id, + response_cost, + ) in prisma_client.user_list_transactons.items(): + batcher.litellm_usertable.update_many( + where={"user_id": user_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.user_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + + ### UPDATE END-USER TABLE ### + if len(prisma_client.end_user_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + end_user_id, + response_cost, + ) in prisma_client.end_user_list_transactons.items(): + max_user_budget = None + if litellm.max_user_budget is not None: + max_user_budget = litellm.max_user_budget + new_user_obj = LiteLLM_EndUserTable( + user_id=end_user_id, spend=response_cost, blocked=False + ) + batcher.litellm_endusertable.update_many( + where={"user_id": end_user_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.end_user_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. 
+ except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + + ### UPDATE KEY TABLE ### + if len(prisma_client.key_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + token, + response_cost, + ) in prisma_client.key_list_transactons.items(): + batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists + where={"token": token}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.key_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + + ### UPDATE TEAM TABLE ### + if len(prisma_client.team_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + try: + async with prisma_client.db.tx(timeout=6000) as transaction: + async with transaction.batch_() as batcher: + for ( + team_id, + response_cost, + ) in prisma_client.team_list_transactons.items(): + batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists + where={"team_id": team_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.team_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. 
+ except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + raise e + + ### UPDATE SPEND LOGS TABLE ### + + +async def monitor_spend_list(prisma_client: PrismaClient): + """ + Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db + """ + if len(prisma_client.user_list_transactons) > 10000: + await update_spend(prisma_client=prisma_client) + + async def _read_request_body(request): """ Asynchronous function to read the request body and parse it as JSON or literal data. diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 13f0850d8..264bb7a70 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -374,7 +374,8 @@ def test_gemini_pro_vision_base64(): print(resp) prompt_tokens = resp.usage.prompt_tokens - + except litellm.RateLimitError as e: + pass except Exception as e: if "500 Internal error encountered.'" in str(e): pass @@ -457,33 +458,43 @@ def test_gemini_pro_function_calling_streaming(): @pytest.mark.asyncio async def test_gemini_pro_async_function_calling(): load_vertex_ai_credentials() - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", + try: + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + "required": ["location"], }, - "required": ["location"], }, - }, - } - ] - messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] - completion = await litellm.acompletion( - model="gemini-pro", messages=messages, tools=tools, tool_choice="auto" - ) - print(f"completion: {completion}") - assert completion.choices[0].message.content is None - assert len(completion.choices[0].message.tool_calls) == 1 + } + ] + messages = [ + {"role": "user", "content": "What's the weather like in Boston today?"} + ] + completion = await litellm.acompletion( + model="gemini-pro", messages=messages, tools=tools, tool_choice="auto" + ) + print(f"completion: {completion}") + assert completion.choices[0].message.content is None + assert len(completion.choices[0].message.tool_calls) == 1 + except litellm.RateLimitError as e: + pass + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}") # raise Exception("it worked!") @@ -499,6 +510,8 @@ def test_vertexai_embedding(): input=["good morning from litellm", "this is another item"], ) print(f"response:", response) + except litellm.RateLimitError as e: + pass except Exception as e: pytest.fail(f"Error occurred: {e}") @@ -513,6 +526,8 @@ async def test_vertexai_aembedding(): input=["good morning from litellm", "this is another item"], ) print(f"response: {response}") + except litellm.RateLimitError as e: + pass except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_update_spend.py b/litellm/tests/test_update_spend.py new file mode 100644 index 000000000..0fd5d9bcf --- /dev/null +++ b/litellm/tests/test_update_spend.py @@ -0,0 +1,95 @@ +# What is this? 
+## This tests the batch update spend logic on the proxy server + + +import sys, os, asyncio, time, random +from datetime import datetime +import traceback +from dotenv import load_dotenv +from fastapi import Request + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +import litellm +from litellm import Router, mock_completion +from litellm.proxy.utils import ProxyLogging +from litellm.proxy._types import UserAPIKeyAuth +from litellm.caching import DualCache +from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token + +import pytest, logging, asyncio +import litellm, asyncio +from litellm.proxy.proxy_server import ( + new_user, + generate_key_fn, + user_api_key_auth, + user_update, + delete_key_fn, + info_key_fn, + update_key_fn, + generate_key_fn, + generate_key_helper_fn, + spend_user_fn, + spend_key_fn, + view_spend_logs, + user_info, + block_user, +) +from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend +from litellm._logging import verbose_proxy_logger + +verbose_proxy_logger.setLevel(level=logging.DEBUG) + +from litellm.proxy._types import ( + NewUserRequest, + GenerateKeyRequest, + DynamoDBArgs, + KeyRequest, + UpdateKeyRequest, + GenerateKeyRequest, + BlockUsers, +) +from litellm.proxy.utils import DBClient +from starlette.datastructures import URL +from litellm.caching import DualCache + +proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + + +@pytest.fixture +def prisma_client(): + from litellm.proxy.proxy_cli import append_query_params + + ### add connection pool + pool timeout args + params = {"connection_limit": 100, "pool_timeout": 60} + database_url = os.getenv("DATABASE_URL") + modified_url = append_query_params(database_url, params) + os.environ["DATABASE_URL"] = modified_url + + # Assuming DBClient is a class that needs to be instantiated + prisma_client = PrismaClient( + 
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj + ) + + # Reset litellm.proxy.proxy_server.prisma_client to None + litellm.proxy.proxy_server.custom_db_client = None + litellm.proxy.proxy_server.litellm_proxy_budget_name = ( + f"litellm-proxy-budget-{time.time()}" + ) + litellm.proxy.proxy_server.user_custom_key_generate = None + + return prisma_client + + +@pytest.mark.asyncio +async def test_batch_update_spend(prisma_client): + prisma_client.user_list_transactons["test-litellm-user-5"] = 23 + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + await update_spend(prisma_client=litellm.proxy.proxy_server.prisma_client) diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index 34a61994b..d31218b8d 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -50,6 +50,7 @@ general_settings: master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234) proxy_budget_rescheduler_min_time: 60 proxy_budget_rescheduler_max_time: 64 + proxy_batch_write_at: 1 # database_url: "postgresql://:@:/" # [OPTIONAL] use for token-based auth to proxy # environment_variables: diff --git a/tests/test_keys.py b/tests/test_keys.py index cba960aca..0419e9f8a 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -329,6 +329,16 @@ async def test_key_info_spend_values(): - make completion call - assert cost is expected value """ + + async def retry_request(func, *args, _max_attempts=5, **kwargs): + for attempt in range(_max_attempts): + try: + return await func(*args, **kwargs) + except aiohttp.client_exceptions.ClientOSError as e: + if attempt + 1 == _max_attempts: + raise # re-raise the last ClientOSError if all attempts failed + print(f"Attempt {attempt+1} failed, retrying...") + async with aiohttp.ClientSession() as session: ## 
Test Spend Update ## # completion @@ -336,7 +346,9 @@ async def test_key_info_spend_values(): key = key_gen["key"] response = await chat_completion(session=session, key=key) await asyncio.sleep(5) - spend_logs = await get_spend_logs(session=session, request_id=response["id"]) + spend_logs = await retry_request( + get_spend_logs, session=session, request_id=response["id"] + ) print(f"spend_logs: {spend_logs}") completion_tokens = spend_logs[0]["completion_tokens"] prompt_tokens = spend_logs[0]["prompt_tokens"]