diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 8ddadd6b14..ec41ef20aa 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1,6 +1,6 @@
 from typing import Optional, List, Any, Literal
 import os, subprocess, hashlib, importlib, asyncio
-import litellm
+import litellm, backoff
 
 ### LOGGING ###
 class ProxyLogging:
@@ -65,6 +65,12 @@ class ProxyLogging:
 
 
 ### DB CONNECTOR ###
+# Define the retry decorator with backoff strategy
+# Function to be called whenever a retry is about to happen
+def on_backoff(details):
+    # The 'tries' key in the details dictionary contains the number of completed tries
+    print(f"Backing off... this was attempt #{details['tries']}")
+
 class PrismaClient:
     def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
         print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
@@ -96,6 +102,13 @@ class PrismaClient:
 
         return hashed_token
 
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
     async def get_data(self, token: str, expires: Optional[Any]=None):
         try:
             hashed_token = self.hash_token(token=token)
@@ -116,6 +129,14 @@
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
+    # Define a retrying strategy with exponential backoff
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
     async def insert_data(self, data: dict):
         """
         Add a key to the database. If it already exists, do nothing.
@@ -138,7 +159,16 @@
             return new_verification_token
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
+            raise e
 
+    # Define a retrying strategy with exponential backoff
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
     async def update_data(self, token: str, data: dict):
         """
         Update existing data
@@ -165,6 +195,14 @@
             print()
 
 
+    # Define a retrying strategy with exponential backoff
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
     async def delete_data(self, tokens: List):
         """
         Allow user to delete a key(s)
diff --git a/litellm/tests/test_proxy_server_spend.py b/litellm/tests/test_proxy_server_spend.py
index f64ad89877..5569b67319 100644
--- a/litellm/tests/test_proxy_server_spend.py
+++ b/litellm/tests/test_proxy_server_spend.py
@@ -1,82 +1,82 @@
-# import openai, json, time, asyncio
-# client = openai.AsyncOpenAI(
-#     api_key="sk-1234",
-#     base_url="http://0.0.0.0:8000"
-# )
+import openai, json, time, asyncio
+client = openai.AsyncOpenAI(
+    api_key="sk-1234",
+    base_url="http://0.0.0.0:8000"
+)
 
-# super_fake_messages = [
-#     {
-#         "role": "user",
-#         "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
-#     },
-#     {
-#         "content": None,
-#         "role": "assistant",
-#         "tool_calls": [
-#             {
-#                 "id": "1",
-#                 "function": {
-#                     "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
-#                     "name": "get_current_weather"
-#                 },
-#                 "type": "function"
-#             },
-#             {
-#                 "id": "2",
-#                 "function": {
-#                     "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
-#                     "name": "get_current_weather"
-#                 },
-#                 "type": "function"
-#             },
-#             {
-#                 "id": "3",
-#                 "function": {
-#                     "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
-#                     "name": "get_current_weather"
-#                 },
-#                 "type": "function"
-#             }
-#         ]
-#     },
-#     {
-#         "tool_call_id": "1",
-#         "role": "tool",
-#         "name": "get_current_weather",
-#         "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
-#     },
-#     {
-#         "tool_call_id": "2",
-#         "role": "tool",
-#         "name": "get_current_weather",
-#         "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
-#     },
-#     {
-#         "tool_call_id": "3",
-#         "role": "tool",
-#         "name": "get_current_weather",
-#         "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
-#     }
-# ]
+super_fake_messages = [
+    {
+        "role": "user",
+        "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
+    },
+    {
+        "content": None,
+        "role": "assistant",
+        "tool_calls": [
+            {
+                "id": "1",
+                "function": {
+                    "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
+                    "name": "get_current_weather"
+                },
+                "type": "function"
+            },
+            {
+                "id": "2",
+                "function": {
+                    "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
+                    "name": "get_current_weather"
+                },
+                "type": "function"
+            },
+            {
+                "id": "3",
+                "function": {
+                    "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
+                    "name": "get_current_weather"
+                },
+                "type": "function"
+            }
+        ]
+    },
+    {
+        "tool_call_id": "1",
+        "role": "tool",
+        "name": "get_current_weather",
+        "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
+    },
+    {
+        "tool_call_id": "2",
+        "role": "tool",
+        "name": "get_current_weather",
+        "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
+    },
+    {
+        "tool_call_id": "3",
+        "role": "tool",
+        "name": "get_current_weather",
+        "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
+    }
+]
 
-# async def chat_completions():
-#     super_fake_response = await client.chat.completions.create(
-#         model="gpt-3.5-turbo",
-#         messages=super_fake_messages,
-#         seed=1337,
-#         stream=False
-#     ) # get a new response from the model where it can see the function response
-#     await asyncio.sleep(1)
-#     return super_fake_response
+async def chat_completions():
+    super_fake_response = await client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        messages=super_fake_messages,
+        seed=1337,
+        stream=False
+    ) # get a new response from the model where it can see the function response
+    await asyncio.sleep(1)
+    return super_fake_response
 
-# async def loadtest_fn(n = 1):
-#     global num_task_cancelled_errors, exception_counts, chat_completions
-#     start = time.time()
-#     tasks = [chat_completions() for _ in range(n)]
-#     chat_completions = await asyncio.gather(*tasks)
-#     successful_completions = [c for c in chat_completions if c is not None]
-#     print(n, time.time() - start, len(successful_completions))
+async def loadtest_fn(n = 1):
+    global num_task_cancelled_errors, exception_counts, chat_completions
+    start = time.time()
+    tasks = [chat_completions() for _ in range(n)]
+    chat_completions = await asyncio.gather(*tasks)
+    successful_completions = [c for c in chat_completions if c is not None]
+    print(n, time.time() - start, len(successful_completions))
 
-# # print(json.dumps(super_fake_response.model_dump(), indent=4))
+# print(json.dumps(super_fake_response.model_dump(), indent=4))
 
-# asyncio.run(loadtest_fn())
\ No newline at end of file
+asyncio.run(loadtest_fn())
\ No newline at end of file
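
Reviewer note: the snippet below is a minimal, standalone sketch (not part of the PR) showing how the `backoff.on_exception` decorator added to the `PrismaClient` methods behaves: exponential waits via `backoff.expo`, a cap of `max_tries=3` attempts and `max_time=10` seconds, and an `on_backoff` callback fired before each retry. The `flaky_query` function and its failure logic are hypothetical stand-ins for a database call.

```python
import asyncio
import backoff


def on_backoff(details):
    # 'tries' holds the number of attempts completed so far
    print(f"Backing off... this was attempt #{details['tries']}")


_attempts = 0  # hypothetical counter used only to simulate transient failures


@backoff.on_exception(
    backoff.expo,          # exponential wait between retries
    Exception,             # retry on any exception
    max_tries=3,           # give up after 3 attempts
    max_time=10,           # or after 10 seconds total, whichever comes first
    on_backoff=on_backoff, # called before each retry
)
async def flaky_query():
    # Hypothetical stand-in for a PrismaClient database call
    global _attempts
    _attempts += 1
    if _attempts < 3:
        raise ConnectionError("transient database error")
    return {"token": "hashed-token", "spend": 0.0}


if __name__ == "__main__":
    # Fails twice, backs off, then succeeds on the third attempt
    print(asyncio.run(flaky_query()))
```

Because `backoff.on_exception` also wraps coroutine functions, the decorated `get_data` / `insert_data` / `update_data` / `delete_data` methods keep their async signatures; once `max_tries` or `max_time` is exhausted, the last exception propagates to the caller as before.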