fix(proxy/utils.py): add backoff/retry logic to db read/writes

Krrish Dholakia 2023-12-08 13:34:19 -08:00
parent 4566210a07
commit 04fc583baf
2 changed files with 116 additions and 78 deletions

proxy/utils.py

@@ -1,6 +1,6 @@
 from typing import Optional, List, Any, Literal
 import os, subprocess, hashlib, importlib, asyncio
-import litellm
+import litellm, backoff
 
 ### LOGGING ###
 class ProxyLogging:
@@ -65,6 +65,12 @@ class ProxyLogging:
 ### DB CONNECTOR ###
+# Define the retry decorator with backoff strategy
+# Function to be called whenever a retry is about to happen
+def on_backoff(details):
+    # The 'tries' key in the details dictionary contains the number of completed tries
+    print(f"Backing off... this was attempt #{details['tries']}")
+
 class PrismaClient:
     def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
         print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
@@ -96,6 +102,13 @@ class PrismaClient:
         return hashed_token
 
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
     async def get_data(self, token: str, expires: Optional[Any]=None):
         try:
             hashed_token = self.hash_token(token=token)
@@ -116,6 +129,14 @@ class PrismaClient:
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
 
+    # Define a retrying strategy with exponential backoff
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
     async def insert_data(self, data: dict):
         """
         Add a key to the database. If it already exists, do nothing.
@@ -138,7 +159,16 @@ class PrismaClient:
             return new_verification_token
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
+            raise e
 
+    # Define a retrying strategy with exponential backoff
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
     async def update_data(self, token: str, data: dict):
         """
         Update existing data
@@ -165,6 +195,14 @@ class PrismaClient:
             print()
 
+    # Define a retrying strategy with exponential backoff
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
     async def delete_data(self, tokens: List):
        """
        Allow user to delete a key(s)
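
For reference, the pattern these hunks apply can be exercised standalone. Below is a minimal sketch using the `backoff` library's documented API (`backoff.on_exception`, `backoff.expo`); the coroutine `flaky_db_read` is a hypothetical stand-in for a Prisma call, not code from this commit:

import asyncio
import backoff

def on_backoff(details):
    # 'tries' in the details dict is the number of completed attempts so far
    print(f"Backing off... this was attempt #{details['tries']}")

@backoff.on_exception(
    backoff.expo,           # exponentially growing (jittered) wait between attempts
    Exception,              # retry on any exception
    max_tries=3,            # give up after 3 attempts...
    max_time=10,            # ...or after 10 seconds total, whichever comes first
    on_backoff=on_backoff,  # called before each retry sleep
)
async def flaky_db_read():
    raise ConnectionError("db unavailable")  # simulate a transient DB failure

# asyncio.run(flaky_db_read())  # prints two backoff messages, then re-raises

Once `max_tries` or `max_time` is exhausted the decorator re-raises the last exception. That is also why this commit adds `raise e` to `insert_data`'s except block: if the handler swallowed the exception, every attempt would look successful and the retry decorator would never fire.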


The second changed file re-enables a load-test script that was previously commented out in full; each of the removed lines is the corresponding added line prefixed with "# ".

@@ -1,82 +1,82 @@
+import openai, json, time, asyncio
+
+client = openai.AsyncOpenAI(
+    api_key="sk-1234",
+    base_url="http://0.0.0.0:8000"
+)
+
+super_fake_messages = [
+    {
+        "role": "user",
+        "content": f"What's the weather like in San Francisco, Tokyo, and Paris? {time.time()}"
+    },
+    {
+        "content": None,
+        "role": "assistant",
+        "tool_calls": [
+            {
+                "id": "1",
+                "function": {
+                    "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
+                    "name": "get_current_weather"
+                },
+                "type": "function"
+            },
+            {
+                "id": "2",
+                "function": {
+                    "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
+                    "name": "get_current_weather"
+                },
+                "type": "function"
+            },
+            {
+                "id": "3",
+                "function": {
+                    "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
+                    "name": "get_current_weather"
+                },
+                "type": "function"
+            }
+        ]
+    },
+    {
+        "tool_call_id": "1",
+        "role": "tool",
+        "name": "get_current_weather",
+        "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
+    },
+    {
+        "tool_call_id": "2",
+        "role": "tool",
+        "name": "get_current_weather",
+        "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
+    },
+    {
+        "tool_call_id": "3",
+        "role": "tool",
+        "name": "get_current_weather",
+        "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
+    }
+]
+
+async def chat_completions():
+    super_fake_response = await client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        messages=super_fake_messages,
+        seed=1337,
+        stream=False
+    )  # get a new response from the model where it can see the function response
+    await asyncio.sleep(1)
+    return super_fake_response
+
+async def loadtest_fn(n = 1):
+    global num_task_cancelled_errors, exception_counts, chat_completions
+    start = time.time()
+    tasks = [chat_completions() for _ in range(n)]
+    chat_completions = await asyncio.gather(*tasks)
+    successful_completions = [c for c in chat_completions if c is not None]
+    print(n, time.time() - start, len(successful_completions))
+    # print(json.dumps(super_fake_response.model_dump(), indent=4))
+
+asyncio.run(loadtest_fn())
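
One caveat in the script above: `loadtest_fn` declares `chat_completions` global and then rebinds it to the list returned by `asyncio.gather`, so calling it a second time would fail with `TypeError: 'list' object is not callable`. A re-runnable variant might look like the sketch below (the `loadtest_fn_safe` name and the `return_exceptions=True` handling are illustrative additions, not part of the commit):

async def loadtest_fn_safe(n=1):
    start = time.time()
    tasks = [chat_completions() for _ in range(n)]
    # bind results to a fresh local name so the coroutine stays callable
    responses = await asyncio.gather(*tasks, return_exceptions=True)
    successes = [r for r in responses if not isinstance(r, Exception)]
    print(n, time.time() - start, len(successes))

# asyncio.run(loadtest_fn_safe(n=10))  # 10 concurrent requests against the proxy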