adding tenacity retries to the completion endpoint (disable with `litellm.retry = False`)

Krrish Dholakia 2023-08-03 10:06:31 -07:00
parent 10832be1e4
commit a8b3fc6c2d
12 changed files with 37 additions and 11 deletions

View file

@@ -3,6 +3,7 @@ failure_callback = []
 set_verbose=False
 telemetry=True
 max_tokens = 256 # OpenAI Defaults
+retry = True # control tenacity retries.
 ####### PROXY PARAMS ################### configurable params if you use proxy models like Helicone
 api_base = None
 headers = None
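
The new flag is module-level and defaults to True. A minimal sketch of how a caller opts out (assuming the usual top-level import):

    import litellm

    litellm.retry = False  # disable tenacity retries on completion calls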

Binary file not shown. (×4)

View file

@@ -7,6 +7,11 @@ import traceback
 import litellm
 from litellm import client, logging, exception_type, timeout, success_callback, failure_callback
 import random
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_random_exponential,
+)  # for exponential backoff

 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv()  # Loading env variables using dotenv
@@ -55,6 +60,7 @@ def get_optional_params(
 ####### COMPLETION ENDPOINTS ################
 #############################################
 @client
+@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry))  # retry call, turn this off by setting `litellm.retry = False`
 @timeout(60)  ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def completion(
     model, messages,  # required params
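
In isolation, the decorator waits with jittered exponential backoff between 1 and 60 seconds, gives up after the second attempt, and re-raises the final exception. A minimal sketch of the same backoff/stop/reraise pattern (`flaky_call` is a hypothetical stand-in for the API request):

    import random
    from tenacity import retry, stop_after_attempt, wait_random_exponential

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True)
    def flaky_call():
        # hypothetical transient failure: retried once with backoff, then re-raised
        if random.random() < 0.5:
            raise ConnectionError("transient failure")
        return "ok"

    print(flaky_call())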

View file

@@ -15,13 +15,14 @@ def logger_fn(model_call_object: dict):
 user_message = "Hello, how are you?"
 messages = [{ "content": user_message,"role": "user"}]
+print(os.environ)
 temp_key = os.environ.get("OPENAI_API_KEY")
 os.environ["OPENAI_API_KEY"] = "bad-key"
 # test on openai completion call
 try:
     response = completion(model="gpt-3.5-turbo", messages=messages, logger_fn=logger_fn, api_key=temp_key)
+    print(f"response: {response}")
 except:
     print(f"error occurred: {traceback.format_exc()}")
     pass
 os.environ["OPENAI_API_KEY"] = temp_key

View file

@@ -4,7 +4,8 @@
 import sys, os
 import traceback
+from dotenv import load_dotenv
+load_dotenv()

 # Get the current directory of the script
 current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -19,7 +20,7 @@ import litellm
 from litellm import embedding, completion
+litellm.set_verbose = True
 litellm.success_callback = ["posthog"]
 litellm.failure_callback = ["slack", "sentry", "posthog"]
@@ -36,3 +37,16 @@ def test_completion_with_empty_model():
 except Exception as e:
     print(f"error occurred: {e}")
     pass
+
+# bad key
+temp_key = os.environ.get("OPENAI_API_KEY")
+os.environ["OPENAI_API_KEY"] = "bad-key"
+# test on openai completion call
+try:
+    response = completion(model="gpt-3.5-turbo", messages=messages)
+    print(f"response: {response}")
+except:
+    print(f"error occurred: {traceback.format_exc()}")
+    pass
+os.environ["OPENAI_API_KEY"] = temp_key

View file

@@ -18,7 +18,8 @@ start_time = time.time()
 try:
     stop_after_10_s(force_timeout=1)
-except:
+except Exception as e:
+    print(e)
     pass
 end_time = time.time()

View file

@@ -37,26 +37,29 @@ def timeout(
         thread = _LoopWrapper()
         thread.start()
         future = asyncio.run_coroutine_threadsafe(async_func(), thread.loop)
-        try:
         local_timeout_duration = timeout_duration
         if "force_timeout" in kwargs:
             local_timeout_duration = kwargs["force_timeout"]
+        try:
             result = future.result(timeout=local_timeout_duration)
         except futures.TimeoutError:
             thread.stop_loop()
-            raise exception_to_raise()
+            raise exception_to_raise(f"A timeout error occurred. The function call took longer than {local_timeout_duration} second(s).")
         thread.stop_loop()
         return result

     @wraps(func)
     async def async_wrapper(*args, **kwargs):
+        local_timeout_duration = timeout_duration
+        if "force_timeout" in kwargs:
+            local_timeout_duration = kwargs["force_timeout"]
         try:
             value = await asyncio.wait_for(
-                func(*args, **kwargs), timeout=timeout_duration
+                func(*args, **kwargs), timeout=local_timeout_duration
             )
             return value
         except asyncio.TimeoutError:
-            raise exception_to_raise()
+            raise exception_to_raise(f"A timeout error occurred. The function call took longer than {local_timeout_duration} second(s).")
     if iscoroutinefunction(func):
         return async_wrapper
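
With the override plumbed through both wrappers, per-call timeouts can be set from the public API. A minimal usage sketch (the model and message are just examples):

    from litellm import completion

    messages = [{"content": "Hello, how are you?", "role": "user"}]
    # override the default 60s timeout for this single call
    response = completion(model="gpt-3.5-turbo", messages=messages, force_timeout=10)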

View file

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 setup(
     name='litellm',
-    version='0.1.226',
+    version='0.1.227',
     description='Library to easily interface with LLM API providers',
     author='BerriAI',
     packages=[