From b6adeec3473422e0f71fe0db8550867a00aaf88b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 6 Feb 2024 17:30:36 -0800 Subject: [PATCH 1/4] fix(proxy_server.py): prisma client fixes for high traffic --- litellm/_logging.py | 2 +- litellm/integrations/langfuse.py | 2 +- litellm/proxy/_types.py | 6 +- litellm/proxy/proxy_server.py | 235 +++++++++++++++++-------------- tests/test_keys.py | 2 +- tests/test_spend_logs.py | 91 +++++++++++- 6 files changed, 224 insertions(+), 114 deletions(-) diff --git a/litellm/_logging.py b/litellm/_logging.py index 438fa9743d..171761c3c2 100644 --- a/litellm/_logging.py +++ b/litellm/_logging.py @@ -3,7 +3,7 @@ import logging set_verbose = False # Create a handler for the logger (you may need to adapt this based on your needs) -handler = logging.StreamHandler() +handler = logging.FileHandler("log_file.txt") handler.setLevel(logging.DEBUG) # Create a formatter and set it for the handler diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index 3031868ec7..bb49cb048a 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -113,7 +113,7 @@ class LangFuseLogger: elif response_obj is not None: input = prompt output = response_obj["choices"][0]["message"].json() - print(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}") + print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}") if self._is_langfuse_v2(): self._log_langfuse_v2( user_id, diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index ed7237141e..ca5d4b05b1 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -391,6 +391,10 @@ class LiteLLM_SpendLogs(LiteLLMBase): startTime: Union[str, datetime, None] endTime: Union[str, datetime, None] user: Optional[str] = "" - metadata: Optional[Json] = {} + metadata: Optional[dict] = {} cache_hit: Optional[str] = "False" cache_key: Optional[str] = None + + +class LiteLLM_SpendLogs_ResponseObject(LiteLLMBase): + response: 
Optional[List[Union[LiteLLM_SpendLogs, Any]]] = None diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 427bb88a9c..5e016c46be 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -432,15 +432,25 @@ async def user_api_key_auth( # Check 2. If user_id for this token is in budget ## Check 2.5 If global proxy is in budget if valid_token.user_id is not None: - if prisma_client is not None: - user_id_information = await prisma_client.get_data( - user_id_list=[valid_token.user_id, litellm_proxy_budget_name], - table_name="user", - query_type="find_all", - ) - if custom_db_client is not None: - user_id_information = await custom_db_client.get_data( - key=valid_token.user_id, table_name="user" + user_id_information = user_api_key_cache.get_cache( + key=valid_token.user_id + ) + if user_id_information is None: + if prisma_client is not None: + user_id_information = await prisma_client.get_data( + user_id_list=[ + valid_token.user_id, + litellm_proxy_budget_name, + ], + table_name="user", + query_type="find_all", + ) + if custom_db_client is not None: + user_id_information = await custom_db_client.get_data( + key=valid_token.user_id, table_name="user" + ) + user_api_key_cache.set_cache( + key=valid_token.user_id, value=user_id_information, ttl=600 ) verbose_proxy_logger.debug( @@ -544,7 +554,7 @@ async def user_api_key_auth( api_key = valid_token.token # Add hashed token to cache - user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) + user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=600) valid_token_dict = _get_pydantic_json_dict(valid_token) valid_token_dict.pop("token", None) """ @@ -837,114 +847,127 @@ async def update_database( - Update that user's row - Update litellm-proxy-budget row (global proxy spend) """ - user_ids = [user_id, litellm_proxy_budget_name] - data_list = [] - for id in user_ids: - if id is None: - continue + try: + user_ids = [user_id, litellm_proxy_budget_name] 
+ data_list = [] + for id in user_ids: + if id is None: + continue + if prisma_client is not None: + existing_spend_obj = await prisma_client.get_data(user_id=id) + elif ( + custom_db_client is not None and id != litellm_proxy_budget_name + ): + existing_spend_obj = await custom_db_client.get_data( + key=id, table_name="user" + ) + verbose_proxy_logger.debug( + f"Updating existing_spend_obj: {existing_spend_obj}" + ) + if existing_spend_obj is None: + existing_spend = 0 + existing_spend_obj = LiteLLM_UserTable( + user_id=id, spend=0, max_budget=None, user_email=None + ) + else: + existing_spend = existing_spend_obj.spend + + # Calculate the new cost by adding the existing cost and response_cost + existing_spend_obj.spend = existing_spend + response_cost + + verbose_proxy_logger.debug(f"new cost: {existing_spend_obj.spend}") + data_list.append(existing_spend_obj) + + # Update the cost column for the given user id if prisma_client is not None: - existing_spend_obj = await prisma_client.get_data(user_id=id) - elif custom_db_client is not None and id != litellm_proxy_budget_name: - existing_spend_obj = await custom_db_client.get_data( - key=id, table_name="user" + await prisma_client.update_data( + data_list=data_list, query_type="update_many", table_name="user" ) - verbose_proxy_logger.debug( - f"Updating existing_spend_obj: {existing_spend_obj}" - ) - if existing_spend_obj is None: - existing_spend = 0 - existing_spend_obj = LiteLLM_UserTable( - user_id=id, spend=0, max_budget=None, user_email=None + elif custom_db_client is not None and user_id is not None: + new_spend = data_list[0].spend + await custom_db_client.update_data( + key=user_id, value={"spend": new_spend}, table_name="user" ) - else: - existing_spend = existing_spend_obj.spend - - # Calculate the new cost by adding the existing cost and response_cost - existing_spend_obj.spend = existing_spend + response_cost - - verbose_proxy_logger.debug(f"new cost: {existing_spend_obj.spend}") - 
data_list.append(existing_spend_obj) - - # Update the cost column for the given user id - if prisma_client is not None: - await prisma_client.update_data( - data_list=data_list, query_type="update_many", table_name="user" - ) - elif custom_db_client is not None and user_id is not None: - new_spend = data_list[0].spend - await custom_db_client.update_data( - key=user_id, value={"spend": new_spend}, table_name="user" - ) + except Exception as e: + verbose_proxy_logger.info(f"Update User DB call failed to execute") ### UPDATE KEY SPEND ### async def _update_key_db(): - verbose_proxy_logger.debug( - f"adding spend to key db. Response cost: {response_cost}. Token: {token}." - ) - if prisma_client is not None: - # Fetch the existing cost for the given token - existing_spend_obj = await prisma_client.get_data(token=token) + try: verbose_proxy_logger.debug( - f"_update_key_db: existing spend: {existing_spend_obj}" + f"adding spend to key db. Response cost: {response_cost}. Token: {token}." ) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost + if prisma_client is not None: + # Fetch the existing cost for the given token + existing_spend_obj = await prisma_client.get_data(token=token) + verbose_proxy_logger.debug( + f"_update_key_db: existing spend: {existing_spend_obj}" + ) + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await prisma_client.update_data(token=token, data={"spend": new_spend}) + verbose_proxy_logger.debug(f"new cost: {new_spend}") + # Update the cost column for the given token + await prisma_client.update_data( + 
token=token, data={"spend": new_spend} + ) - valid_token = user_api_key_cache.get_cache(key=token) - if valid_token is not None: - valid_token.spend = new_spend - user_api_key_cache.set_cache(key=token, value=valid_token) - elif custom_db_client is not None: - # Fetch the existing cost for the given token - existing_spend_obj = await custom_db_client.get_data( - key=token, table_name="key" - ) - verbose_proxy_logger.debug( - f"_update_key_db existing spend: {existing_spend_obj}" - ) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost + valid_token = user_api_key_cache.get_cache(key=token) + if valid_token is not None: + valid_token.spend = new_spend + user_api_key_cache.set_cache(key=token, value=valid_token) + elif custom_db_client is not None: + # Fetch the existing cost for the given token + existing_spend_obj = await custom_db_client.get_data( + key=token, table_name="key" + ) + verbose_proxy_logger.debug( + f"_update_key_db existing spend: {existing_spend_obj}" + ) + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await custom_db_client.update_data( - key=token, value={"spend": new_spend}, table_name="key" - ) + verbose_proxy_logger.debug(f"new cost: {new_spend}") + # Update the cost column for the given token + await custom_db_client.update_data( + key=token, value={"spend": new_spend}, table_name="key" + ) - valid_token = user_api_key_cache.get_cache(key=token) - if valid_token is not None: - valid_token.spend = new_spend - user_api_key_cache.set_cache(key=token, value=valid_token) + valid_token = 
user_api_key_cache.get_cache(key=token) + if valid_token is not None: + valid_token.spend = new_spend + user_api_key_cache.set_cache(key=token, value=valid_token) + except Exception as e: + verbose_proxy_logger.info(f"Update Key DB Call failed to execute") ### UPDATE SPEND LOGS ### async def _insert_spend_log_to_db(): - # Helper to generate payload to log - verbose_proxy_logger.debug("inserting spend log to db") - payload = get_logging_payload( - kwargs=kwargs, - response_obj=completion_response, - start_time=start_time, - end_time=end_time, - ) + try: + # Helper to generate payload to log + verbose_proxy_logger.debug("inserting spend log to db") + payload = get_logging_payload( + kwargs=kwargs, + response_obj=completion_response, + start_time=start_time, + end_time=end_time, + ) - payload["spend"] = response_cost - if prisma_client is not None: - await prisma_client.insert_data(data=payload, table_name="spend") + payload["spend"] = response_cost + if prisma_client is not None: + await prisma_client.insert_data(data=payload, table_name="spend") - elif custom_db_client is not None: - await custom_db_client.insert_data(payload, table_name="spend") + elif custom_db_client is not None: + await custom_db_client.insert_data(payload, table_name="spend") + except Exception as e: + verbose_proxy_logger.info(f"Update Spend Logs DB failed to execute") asyncio.create_task(_update_user_db()) asyncio.create_task(_update_key_db()) @@ -1534,7 +1557,7 @@ async def generate_key_helper_fn( user_api_key_cache.set_cache( key=hashed_token, value=LiteLLM_VerificationToken(**saved_token), # type: ignore - ttl=60, + ttl=600, ) if prisma_client is not None: ## CREATE USER (If necessary) @@ -2979,6 +3002,9 @@ async def spend_user_fn( "/spend/logs", tags=["budget & spend Tracking"], dependencies=[Depends(user_api_key_auth)], + responses={ + 200: {"model": List[LiteLLM_SpendLogs]}, + }, ) async def view_spend_logs( api_key: Optional[str] = fastapi.Query( @@ -3049,7 +3075,8 @@ async def 
view_spend_logs( query_type="find_unique", key_val={"key": "request_id", "value": request_id}, ) - return [spend_log] + response = LiteLLM_SpendLogs_ResponseObject(response=[spend_log]) + return response elif user_id is not None: spend_log = await prisma_client.get_data( table_name="spend", @@ -3065,7 +3092,7 @@ async def view_spend_logs( table_name="spend", query_type="find_all" ) - return spend_logs + return spend_log return None diff --git a/tests/test_keys.py b/tests/test_keys.py index 6740308ac5..da7dc19b46 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -490,7 +490,7 @@ async def test_key_crossing_budget(): @pytest.mark.asyncio -async def test_key_zinfo_spend_values_sagemaker(): +async def test_key_info_spend_values_sagemaker(): """ Tests the sync streaming loop to ensure spend is correctly calculated. - create key diff --git a/tests/test_spend_logs.py b/tests/test_spend_logs.py index 1907c4daee..6db0e301e7 100644 --- a/tests/test_spend_logs.py +++ b/tests/test_spend_logs.py @@ -1,7 +1,7 @@ # What this tests? ## Tests /spend endpoints. -import pytest +import pytest, time, uuid import asyncio import aiohttp @@ -26,17 +26,17 @@ async def generate_key(session, models=[]): return await response.json() -async def chat_completion(session, key): +async def chat_completion(session, key, model="gpt-3.5-turbo"): url = "http://0.0.0.0:4000/chat/completions" headers = { "Authorization": f"Bearer {key}", "Content-Type": "application/json", } data = { - "model": "gpt-3.5-turbo", + "model": model, "messages": [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello!"}, + {"role": "user", "content": f"Hello! 
{uuid.uuid4()}"}, ], } @@ -53,8 +53,37 @@ async def chat_completion(session, key): return await response.json() -async def get_spend_logs(session, request_id): - url = f"http://0.0.0.0:4000/spend/logs?request_id={request_id}" +async def chat_completion_high_traffic(session, key, model="gpt-3.5-turbo"): + url = "http://0.0.0.0:4000/chat/completions" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + data = { + "model": model, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": f"Hello! {uuid.uuid4()}"}, + ], + } + try: + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + return await response.json() + except Exception as e: + return None + + +async def get_spend_logs(session, request_id=None, api_key=None): + if api_key is not None: + url = f"http://0.0.0.0:4000/spend/logs?api_key={api_key}" + else: + url = f"http://0.0.0.0:4000/spend/logs?request_id={request_id}" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} async with session.get(url, headers=headers) as response: @@ -82,3 +111,53 @@ async def test_spend_logs(): response = await chat_completion(session=session, key=key) await asyncio.sleep(5) await get_spend_logs(session=session, request_id=response["id"]) + + +@pytest.mark.asyncio +async def test_spend_logs_high_traffic(): + """ + - Create key + - Make 30 concurrent calls + - Get all logs for that key + - Wait 10s + - Assert it's 30 + """ + + async def retry_request(func, *args, _max_attempts=5, **kwargs): + for attempt in range(_max_attempts): + try: + return await func(*args, **kwargs) + except ( + aiohttp.client_exceptions.ClientOSError, + aiohttp.client_exceptions.ServerDisconnectedError, + ) as e: + if attempt + 1 == _max_attempts: + raise # 
re-raise the last ClientOSError if all attempts failed + print(f"Attempt {attempt+1} failed, retrying...") + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=600) + ) as session: + start = time.time() + key_gen = await generate_key(session=session) + key = key_gen["key"] + n = 1000 + tasks = [ + retry_request( + chat_completion_high_traffic, + session=session, + key=key, + model="azure-gpt-3.5", + ) + for _ in range(n) + ] + chat_completions = await asyncio.gather(*tasks) + successful_completions = [c for c in chat_completions if c is not None] + print(f"Num successful completions: {len(successful_completions)}") + await asyncio.sleep(10) + response = await get_spend_logs(session=session, api_key=key) + print(f"response: {response}") + print(f"len responses: {len(response)}") + assert len(response) == n + print(n, time.time() - start, len(response)) + raise Exception("it worked!") From 4a0df3cb4fe648a6e77924a420c85fccc5360e6a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 6 Feb 2024 19:39:49 -0800 Subject: [PATCH 2/4] fix(proxy_cli.py-&&-proxy_server.py): bump reset budget intervals and fix pool limits for prisma connections --- litellm/proxy/proxy_cli.py | 17 ++++++++++++++++- litellm/proxy/proxy_server.py | 6 +++++- litellm/proxy/utils.py | 8 +------- tests/test_keys.py | 9 +++++---- tests/test_spend_logs.py | 14 +++++++++----- 5 files changed, 36 insertions(+), 18 deletions(-) diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 890cf52941..4982900948 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -5,6 +5,7 @@ import random from datetime import datetime import importlib from dotenv import load_dotenv +import urllib.parse as urlparse sys.path.append(os.getcwd()) @@ -17,6 +18,15 @@ import shutil telemetry = None +def append_query_params(url, params): + parsed_url = urlparse.urlparse(url) + parsed_query = urlparse.parse_qs(parsed_url.query) + parsed_query.update(params) + 
encoded_query = urlparse.urlencode(parsed_query, doseq=True) + modified_url = urlparse.urlunparse(parsed_url._replace(query=encoded_query)) + return modified_url + + def run_ollama_serve(): try: command = ["ollama", "serve"] @@ -416,6 +426,12 @@ def run_server( if os.getenv("DATABASE_URL", None) is not None: try: + ### add connection pool + pool timeout args + params = {"connection_limit": 500, "pool_timeout": 60} + database_url = os.getenv("DATABASE_URL") + modified_url = append_query_params(database_url, params) + os.environ["DATABASE_URL"] = modified_url + ### subprocess.run(["prisma"], capture_output=True) is_prisma_runnable = True except FileNotFoundError: @@ -522,6 +538,5 @@ def run_server( ).run() # Run gunicorn - if __name__ == "__main__": run_server() diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 5e016c46be..ca1c26f7c9 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -841,6 +841,10 @@ async def update_database( f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}" ) + ### [TODO] STEP 1: GET KEY + USER SPEND ### (key, user) + + ### [TODO] STEP 2: UPDATE SPEND ### (key, user, spend logs) + ### UPDATE USER SPEND ### async def _update_user_db(): """ @@ -1922,7 +1926,7 @@ async def startup_event(): ### START BUDGET SCHEDULER ### scheduler = AsyncIOScheduler() interval = random.randint( - 7, 14 + 597, 605 ) # random interval, so multiple workers avoid resetting budget at the same time scheduler.add_job(reset_budget, "interval", seconds=interval, args=[prisma_client]) scheduler.start() diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 84b09d7265..178403e44b 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -391,13 +391,7 @@ class PrismaClient: # Now you can import the Prisma Client from prisma import Prisma # type: ignore - self.db = Prisma( - http={ - "limits": httpx.Limits( - max_connections=1000, 
max_keepalive_connections=100 - ) - } - ) # Client to connect to Prisma db + self.db = Prisma() # Client to connect to Prisma db def hash_token(self, token: str): # Hash the string using SHA-256 diff --git a/tests/test_keys.py b/tests/test_keys.py index da7dc19b46..28ce025112 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -431,9 +431,9 @@ async def test_key_info_spend_values_image_generation(): @pytest.mark.asyncio async def test_key_with_budgets(): """ - - Create key with budget and 5s duration + - Create key with budget and 5min duration - Get 'reset_at' value - - wait 5s + - wait 10min (budget reset runs every 10mins.) - Check if value updated """ from litellm.proxy.utils import hash_token @@ -449,8 +449,8 @@ async def test_key_with_budgets(): reset_at_init_value = key_info["info"]["budget_reset_at"] reset_at_new_value = None i = 0 + await asyncio.sleep(610) while i < 3: - await asyncio.sleep(30) key_info = await get_key_info(session=session, get_key=key, call_key=key) reset_at_new_value = key_info["info"]["budget_reset_at"] try: @@ -458,6 +458,7 @@ async def test_key_with_budgets(): break except: i + 1 + await asyncio.sleep(5) assert reset_at_init_value != reset_at_new_value @@ -481,7 +482,7 @@ async def test_key_crossing_budget(): response = await chat_completion(session=session, key=key) print("response 1: ", response) - await asyncio.sleep(2) + await asyncio.sleep(10) try: response = await chat_completion(session=session, key=key) pytest.fail("Should have failed - Key crossed it's budget") diff --git a/tests/test_spend_logs.py b/tests/test_spend_logs.py index 6db0e301e7..4d7ad175f9 100644 --- a/tests/test_spend_logs.py +++ b/tests/test_spend_logs.py @@ -113,6 +113,7 @@ async def test_spend_logs(): await get_spend_logs(session=session, request_id=response["id"]) +@pytest.mark.skip(reason="High traffic load test, meant to be run locally") @pytest.mark.asyncio async def test_spend_logs_high_traffic(): """ @@ -155,9 +156,12 @@ async def 
test_spend_logs_high_traffic(): successful_completions = [c for c in chat_completions if c is not None] print(f"Num successful completions: {len(successful_completions)}") await asyncio.sleep(10) - response = await get_spend_logs(session=session, api_key=key) - print(f"response: {response}") - print(f"len responses: {len(response)}") - assert len(response) == n - print(n, time.time() - start, len(response)) + try: + response = await retry_request(get_spend_logs, session=session, api_key=key) + print(f"response: {response}") + print(f"len responses: {len(response)}") + assert len(response) == n + print(n, time.time() - start, len(response)) + except: + print(n, time.time() - start, 0) raise Exception("it worked!") From 4174471dacfca64ec7c3e606ecb557d0781b609f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 6 Feb 2024 22:09:30 -0800 Subject: [PATCH 3/4] fix(proxy_server.py): fix endpoint --- litellm/proxy/proxy_server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ca1c26f7c9..5eb20e22ef 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3079,8 +3079,7 @@ async def view_spend_logs( query_type="find_unique", key_val={"key": "request_id", "value": request_id}, ) - response = LiteLLM_SpendLogs_ResponseObject(response=[spend_log]) - return response + return [spend_log] elif user_id is not None: spend_log = await prisma_client.get_data( table_name="spend", From fd9c7a90af48bd9db3db8b10a267e4671d3bf96c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 6 Feb 2024 23:06:05 -0800 Subject: [PATCH 4/4] fix(proxy_server.py): update user cache with new spend --- litellm/_logging.py | 2 +- litellm/proxy/proxy_server.py | 29 ++++++++++++++++++++++------- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/litellm/_logging.py b/litellm/_logging.py index 171761c3c2..438fa9743d 100644 --- a/litellm/_logging.py +++ 
b/litellm/_logging.py @@ -3,7 +3,7 @@ import logging set_verbose = False # Create a handler for the logger (you may need to adapt this based on your needs) -handler = logging.FileHandler("log_file.txt") +handler = logging.StreamHandler() handler.setLevel(logging.DEBUG) # Create a formatter and set it for the handler diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 5eb20e22ef..6b89b0f1b5 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -432,10 +432,18 @@ async def user_api_key_auth( # Check 2. If user_id for this token is in budget ## Check 2.5 If global proxy is in budget if valid_token.user_id is not None: - user_id_information = user_api_key_cache.get_cache( - key=valid_token.user_id - ) - if user_id_information is None: + user_id_list = [valid_token.user_id, litellm_proxy_budget_name] + user_id_information = None + for id in user_id_list: + value = user_api_key_cache.get_cache(key=id) + if value is not None: + if user_id_information is None: + user_id_information = [] + user_id_information.append(value) + if user_id_information is None or ( + isinstance(user_id_information, list) + and len(user_id_information) < 2 + ): if prisma_client is not None: user_id_information = await prisma_client.get_data( user_id_list=[ @@ -445,13 +453,14 @@ async def user_api_key_auth( table_name="user", query_type="find_all", ) + for _id in user_id_information: + user_api_key_cache.set_cache( + key=_id["user_id"], value=_id, ttl=600 + ) if custom_db_client is not None: user_id_information = await custom_db_client.get_data( key=valid_token.user_id, table_name="user" ) - user_api_key_cache.set_cache( - key=valid_token.user_id, value=user_id_information, ttl=600 - ) verbose_proxy_logger.debug( f"user_id_information: {user_id_information}" @@ -879,6 +888,12 @@ async def update_database( # Calculate the new cost by adding the existing cost and response_cost existing_spend_obj.spend = existing_spend + response_cost + 
valid_token = user_api_key_cache.get_cache(key=id) + if valid_token is not None and isinstance(valid_token, dict): + user_api_key_cache.set_cache( + key=id, value=existing_spend_obj.json() + ) + verbose_proxy_logger.debug(f"new cost: {existing_spend_obj.spend}") data_list.append(existing_spend_obj)