custom timeout decorator

Krrish Dholakia 2023-08-01 12:20:25 -07:00
parent 79847145f8
commit 7b2901be9e
10 changed files with 121 additions and 16 deletions

litellm/__init__.py

@@ -25,6 +25,7 @@ open_ai_embedding_models = [
     'text-embedding-ada-002'
 ]
+from .timeout import timeout
 from .utils import client, logging, exception_type # Import helper symbols from utils.py
 from .main import * # Import all the symbols from main.py
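
With this re-export, the decorator can be imported from the package root rather than from litellm.timeout; a minimal sketch:

    # The new package-level re-export in action
    from litellm import timeout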

Binary file not shown.

litellm/main.py

@@ -6,8 +6,7 @@ import traceback
 import dotenv
 import traceback
 import litellm
-from litellm import client, logging, exception_type
-from litellm import success_callback, failure_callback
+from litellm import client, logging, exception_type, timeout, success_callback, failure_callback
 import random
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
@@ -59,7 +58,7 @@ def get_optional_params(
 ####### COMPLETION ENDPOINTS ################
 #############################################
 @client
-@func_set_timeout(180, allowOverride=True) ## https://pypi.org/project/func-timeout/ - timeouts, in case calls hang (e.g. Azure)
+@timeout(60) ## set timeouts, in case calls hang (e.g. Azure)
 def completion(
     model, messages, # required params
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
@@ -67,7 +66,7 @@ def completion(
     temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="",
     # Optional liteLLM function params
-    *, forceTimeout=60, azure=False, logger_fn=None, verbose=False
+    *, force_timeout=60, azure=False, logger_fn=None, verbose=False
 ):
     try:
         # check if user passed in any of the OpenAI optional params

litellm/tests/test_exceptions.py

@@ -1,11 +1,3 @@
-from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, OpenAIError
-import os
-import sys
-import traceback
-sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
-import litellm
-from litellm import embedding, completion
-from concurrent.futures import ThreadPoolExecutor
 #### What this tests ####
 # This tests exception mapping -> trigger an exception from an llm provider -> assert if output is of the expected type
@@ -16,6 +8,15 @@ from concurrent.futures import ThreadPoolExecutor
 # Approach: Run each model through the test -> assert if the correct error (always the same one) is triggered
+from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, OpenAIError
+import os
+import sys
+import traceback
+sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+import litellm
+from litellm import embedding, completion
+from concurrent.futures import ThreadPoolExecutor
 models = ["gpt-3.5-turbo", "chatgpt-test", "claude-instant-1", "command-nightly", "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"]
 # Test 1: Rate Limit Errors

litellm/tests/test_timeout.py Normal file

@@ -0,0 +1,26 @@
+#### What this tests ####
+# This tests the timeout decorator
+
+import sys, os
+import traceback
+sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+import time
+from litellm import timeout
+
+@timeout(10)
+def stop_after_10_s(force_timeout=60):
+    print("Stopping after 10 seconds")
+    time.sleep(10)
+    return
+
+start_time = time.time()
+try:
+    stop_after_10_s(force_timeout=1)
+except:
+    pass
+
+end_time = time.time()
+print(f"total time: {end_time-start_time}")
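
This test only exercises the synchronous path. The decorator also has an asyncio.wait_for branch for coroutines; a sketch of an equivalent async check (not part of this commit) might look like:

    import asyncio
    from litellm import timeout

    @timeout(1)
    async def slow_async_call():
        await asyncio.sleep(10)

    try:
        asyncio.run(slow_async_call())
    except Exception:  # openai.error.Timeout on expiry
        print("async call timed out as expected")

Note that the per-call force_timeout override only applies to the synchronous wrapper; the async branch uses the decorator's own timeout_duration.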

litellm/timeout.py Normal file (+80 lines)

@@ -0,0 +1,80 @@
+"""
+Module containing "timeout" decorator for sync and async callables.
+"""
+
+import asyncio
+from concurrent import futures
+from inspect import iscoroutinefunction
+from functools import wraps
+from threading import Thread
+
+from openai.error import Timeout
+
+
+def timeout(
+    timeout_duration: float = None, exception_to_raise=Timeout
+):
+    """
+    Wraps a function to raise the specified exception if execution time
+    is greater than the specified timeout.
+
+    Works with both synchronous and asynchronous callables, but with synchronous
+    ones it will introduce some overhead due to the backend use of threads and asyncio.
+
+    :param float timeout_duration: Timeout duration in seconds. If None, the callable won't time out.
+    :param OpenAIError exception_to_raise: Exception to raise when the callable times out.
+        Defaults to openai.error.Timeout.
+    :return: The decorated function.
+    :rtype: callable
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            async def async_func():
+                return func(*args, **kwargs)
+
+            # Run the wrapped call on a dedicated event-loop thread so the
+            # timeout can be enforced from the calling thread via the future.
+            thread = _LoopWrapper()
+            thread.start()
+            future = asyncio.run_coroutine_threadsafe(async_func(), thread.loop)
+            try:
+                # A per-call force_timeout kwarg overrides the decorator default.
+                local_timeout_duration = timeout_duration
+                if "force_timeout" in kwargs:
+                    local_timeout_duration = kwargs["force_timeout"]
+                result = future.result(timeout=local_timeout_duration)
+            except futures.TimeoutError:
+                thread.stop_loop()
+                raise exception_to_raise()
+            thread.stop_loop()
+            return result
+
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            try:
+                value = await asyncio.wait_for(
+                    func(*args, **kwargs), timeout=timeout_duration
+                )
+                return value
+            except asyncio.TimeoutError:
+                raise exception_to_raise()
+
+        if iscoroutinefunction(func):
+            return async_wrapper
+        return wrapper
+
+    return decorator
+
+
+class _LoopWrapper(Thread):
+    """Daemon thread that owns a private asyncio event loop."""
+
+    def __init__(self):
+        super().__init__(daemon=True)
+        self.loop = asyncio.new_event_loop()
+
+    def run(self) -> None:
+        self.loop.run_forever()
+        self.loop.call_soon_threadsafe(self.loop.close)
+
+    def stop_loop(self):
+        for task in asyncio.all_tasks(self.loop):
+            task.cancel()
+        self.loop.call_soon_threadsafe(self.loop.stop)
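
The exception class is pluggable via exception_to_raise; a small sketch with an arbitrary substitute (names and timings are illustrative):

    import time
    from litellm import timeout

    # Raise RuntimeError instead of the default openai.error.Timeout
    @timeout(timeout_duration=2, exception_to_raise=RuntimeError)
    def hang_for_ten_seconds():
        time.sleep(10)

    hang_for_ten_seconds()  # raises RuntimeError after roughly 2 seconds

Because the sync wrapper waits on future.result(), the caller gets the exception promptly even though the daemon loop thread may still be blocked in the underlying call.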

requirements.txt

@@ -1,6 +1,5 @@
 openai
 cohere
-func_timeout
 anthropic
 replicate
 pytest

setup.py

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 setup(
     name='litellm',
-    version='0.1.206',
+    version='0.1.207',
     description='Library to easily interface with LLM API providers',
     author='BerriAI',
     packages=[
@@ -11,11 +11,10 @@ setup(
     install_requires=[
         'openai',
         'cohere',
-        'func_timeout',
         'pytest',
         'anthropic',
         'replicate',
         'python-dotenv',
-        'openai[datalib]'
+        'openai[datalib]',
     ],
 )