Merge pull request #32 from BerriAI/set-timeouts

custom timeout decorator
Krish Dholakia 2023-08-01 14:45:03 -07:00 committed by GitHub
commit bb49f1cdba
GPG key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
18 changed files with 247 additions and 129 deletions

View file

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: litellm
-Version: 0.1.2
+Version: 0.1.207
 Summary: Library to easily interface with LLM API providers
 Author: BerriAI
 License-File: LICENSE

View file

@@ -3,6 +3,8 @@ README.md
 setup.py
 litellm/__init__.py
 litellm/main.py
+litellm/timeout.py
+litellm/utils.py
 litellm.egg-info/PKG-INFO
 litellm.egg-info/SOURCES.txt
 litellm.egg-info/dependency_links.txt

View file

@@ -1,2 +1,7 @@
 openai
 cohere
+pytest
+anthropic
+replicate
+python-dotenv
+openai[datalib]

View file

@@ -25,6 +25,7 @@ open_ai_embedding_models = [
     'text-embedding-ada-002'
 ]
+from .timeout import timeout
 from .utils import client, logging, exception_type # Import all the symbols from main.py
 from .main import * # Import all the symbols from main.py

Binary file not shown.

View file

@@ -1,19 +1,15 @@
 import os, openai, cohere, replicate, sys
 from typing import Any
-from func_timeout import func_set_timeout, FunctionTimedOut
 from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
 import traceback
 import dotenv
 import traceback
 import litellm
-from litellm import client, logging, exception_type
-from litellm import success_callback, failure_callback
+from litellm import client, logging, exception_type, timeout, success_callback, failure_callback
 import random

 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv

 def get_optional_params(
     # 12 optional params
     functions = [],
@@ -59,7 +55,7 @@ def get_optional_params(
 ####### COMPLETION ENDPOINTS ################
 #############################################
 @client
-@func_set_timeout(180, allowOverride=True) ## https://pypi.org/project/func-timeout/ - timeouts, in case calls hang (e.g. Azure)
+@timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def completion(
     model, messages, # required params
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
@@ -67,7 +63,7 @@ def completion(
     temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="",
     # Optional liteLLM function params
-    *, forceTimeout=60, azure=False, logger_fn=None, verbose=False
+    *, force_timeout=60, azure=False, logger_fn=None, verbose=False
 ):
     try:
         # check if user passed in any of the OpenAI optional params
@@ -254,8 +250,8 @@ def completion(
 ### EMBEDDING ENDPOINTS ####################
 @client
-@func_set_timeout(60, allowOverride=True) ## https://pypi.org/project/func-timeout/
-def embedding(model, input=[], azure=False, forceTimeout=60, logger_fn=None):
+@timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
+def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):
     response = None
     if azure == True:
         # azure configs
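Besides swapping func_set_timeout for the homegrown @timeout decorator, these hunks rename the camelCase forceTimeout kwarg to force_timeout. A minimal caller-side sketch, not part of the diff (the model choice and API keys in the environment are assumptions):

    from litellm import completion

    messages = [{"role": "user", "content": "Hey, how's it going?"}]

    # If the provider hangs for more than 10s, the call now raises
    # openai.error.Timeout instead of waiting out the 60s default.
    response = completion(model="gpt-3.5-turbo", messages=messages, force_timeout=10)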

View file

@@ -1,3 +1,13 @@
+#### What this tests ####
+# This tests exception mapping -> trigger an exception from an llm provider -> assert if output is of the expected type
+# # 5 providers -> OpenAI, Azure, Anthropic, Cohere, Replicate
+# # 3 main types of exceptions -> - Rate Limit Errors, Context Window Errors, Auth errors (incorrect/rotated key, etc.)
+# # Approach: Run each model through the test -> assert if the correct error (always the same one) is triggered
 # from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, OpenAIError
 # import os
 # import sys
@@ -6,15 +16,6 @@
 # import litellm
 # from litellm import embedding, completion
 # from concurrent.futures import ThreadPoolExecutor
-# #### What this tests ####
-# # This tests exception mapping -> trigger an exception from an llm provider -> assert if output is of the expected type
-# # 5 providers -> OpenAI, Azure, Anthropic, Cohere, Replicate
-# # 3 main types of exceptions -> - Rate Limit Errors, Context Window Errors, Auth errors (incorrect/rotated key, etc.)
-# # Approach: Run each model through the test -> assert if the correct error (always the same one) is triggered
 # models = ["gpt-3.5-turbo", "chatgpt-test", "claude-instant-1", "command-nightly", "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"]
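The whole file is still commented out; this diff only moves the header block to the top. For reference, a hedged sketch of the approach that header describes — the test name, parametrization, and the broad pytest.raises(Exception) are my own guesses, since the final mapped exception types are not in this diff:

    import pytest
    from litellm import completion

    # Subset of the model list from the file; provider keys come from the env.
    models = ["gpt-3.5-turbo", "claude-instant-1", "command-nightly"]

    @pytest.mark.parametrize("model", models)
    def test_invalid_key_raises(model, monkeypatch):
        # Auth-error case: an incorrect/rotated key should surface some
        # exception that litellm's exception_type() can map to a common type.
        monkeypatch.setenv("OPENAI_API_KEY", "bad-key")
        monkeypatch.setenv("ANTHROPIC_API_KEY", "bad-key")
        monkeypatch.setenv("COHERE_API_KEY", "bad-key")
        with pytest.raises(Exception):
            completion(model=model, messages=[{"role": "user", "content": "hi"}])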

View file

@@ -0,0 +1,26 @@
+#### What this tests ####
+# This tests the timeout decorator
+
+import sys, os
+import traceback
+
+sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+import time
+from litellm import timeout
+
+@timeout(10)
+def stop_after_10_s(force_timeout=60):
+    print("Stopping after 10 seconds")
+    time.sleep(10)
+    return
+
+start_time = time.time()
+try:
+    stop_after_10_s(force_timeout=1)
+except:
+    pass
+
+end_time = time.time()
+print(f"total time: {end_time-start_time}")

litellm/timeout.py Normal file (80 lines)
View file

@ -0,0 +1,80 @@
"""
Module containing "timeout" decorator for sync and async callables.
"""
import asyncio
from concurrent import futures
from inspect import iscoroutinefunction
from functools import wraps
from threading import Thread
from openai.error import Timeout
def timeout(
timeout_duration: float = None, exception_to_raise = Timeout
):
"""
Wraps a function to raise the specified exception if execution time
is greater than the specified timeout.
Works with both synchronous and asynchronous callables, but with synchronous ones will introduce
some overhead due to the backend use of threads and asyncio.
:param float timeout_duration: Timeout duration in seconds. If none callable won't time out.
:param OpenAIError exception_to_raise: Exception to raise when the callable times out.
Defaults to TimeoutError.
:return: The decorated function.
:rtype: callable
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
async def async_func():
return func(*args, **kwargs)
thread = _LoopWrapper()
thread.start()
future = asyncio.run_coroutine_threadsafe(async_func(), thread.loop)
try:
local_timeout_duration = timeout_duration
if "force_timeout" in kwargs:
local_timeout_duration = kwargs["force_timeout"]
result = future.result(timeout=local_timeout_duration)
except futures.TimeoutError:
thread.stop_loop()
raise exception_to_raise()
thread.stop_loop()
return result
@wraps(func)
async def async_wrapper(*args, **kwargs):
try:
value = await asyncio.wait_for(
func(*args, **kwargs), timeout=timeout_duration
)
return value
except asyncio.TimeoutError:
raise exception_to_raise()
if iscoroutinefunction(func):
return async_wrapper
return wrapper
return decorator
class _LoopWrapper(Thread):
def __init__(self):
super().__init__(daemon=True)
self.loop = asyncio.new_event_loop()
def run(self) -> None:
self.loop.run_forever()
self.loop.call_soon_threadsafe(self.loop.close)
def stop_loop(self):
for task in asyncio.all_tasks(self.loop):
task.cancel()
self.loop.call_soon_threadsafe(self.loop.stop)
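The decorator has two distinct paths. For sync callables, the call is shipped to a daemon thread's event loop and the caller blocks on future.result(timeout=...), so on timeout the caller stops waiting but the underlying call keeps running in the background thread; for async callables, asyncio.wait_for actually cancels the coroutine. Note also that the force_timeout override is only inspected on the sync path. A small self-contained sketch of both paths (function names here are illustrative, not part of the commit):

    import asyncio
    import time
    from openai.error import Timeout
    from litellm import timeout

    @timeout(2)
    def slow_sync(force_timeout=60):
        time.sleep(5)  # sync path: the waiting thread gives up after 2s (or force_timeout)

    @timeout(2)
    async def slow_async():
        await asyncio.sleep(5)  # async path: asyncio.wait_for cancels the coroutine after 2s

    try:
        slow_sync()
    except Timeout:
        print("sync call timed out after ~2s")

    try:
        asyncio.run(slow_async())
    except Timeout:
        print("async call timed out after ~2s")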

View file

@@ -101,6 +101,7 @@ def client(original_function):
 ####### HELPER FUNCTIONS ################
 def set_callbacks(callback_list):
     global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel
+    try:
     for callback in callback_list:
         if callback == "sentry":
             try:
@@ -136,10 +137,13 @@ def set_callbacks(callback_list):
             )
             alerts_channel = os.environ["SLACK_API_CHANNEL"]
             print_verbose(f"Initialized Slack App: {slack_app}")
+    except:
+        pass

 def handle_failure(exception, traceback_exception, args, kwargs):
     global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel
+    try:
     print_verbose(f"handle_failure args: {args}")
     print_verbose(f"handle_failure kwargs: {kwargs}")
@@ -196,8 +200,11 @@ def handle_failure(exception, traceback_exception, args, kwargs):
         }
         failure_handler(call_details)
         pass
+    except:
+        pass

 def handle_success(*args, **kwargs):
+    try:
     success_handler = additional_details.pop("success_handler", None)
     failure_handler = additional_details.pop("failure_handler", None)
     additional_details["Event_Name"] = additional_details.pop("successful_event_name", "litellm.succes_query")
@@ -225,6 +232,8 @@ def handle_success(*args, **kwargs):
     if success_handler and callable(success_handler):
         success_handler(args, kwargs)
     pass
+    except:
+        pass

 def exception_type(model, original_exception):
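This change is one pattern applied three times: wrap each callback handler's body in a broad try/except so a broken logging integration can never crash the user's completion call. A standalone sketch of the idea (the names below are mine, not litellm's):

    def safe_fire(handler, *args, **kwargs):
        try:
            handler(*args, **kwargs)
        except:
            # Swallow callback errors: the LLM call itself already finished,
            # so telemetry must not raise a second, unrelated exception.
            pass

    def flaky_posthog_callback(event):
        raise RuntimeError("telemetry backend is down")

    safe_fire(flaky_posthog_callback, {"model": "gpt-3.5-turbo"})  # no exception escapes
    print("completion result is still returned to the user")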

View file

@@ -1,6 +1,5 @@
 openai
 cohere
-func_timeout
 anthropic
 replicate
 pytest

View file

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 setup(
     name='litellm',
-    version='0.1.206',
+    version='0.1.207',
     description='Library to easily interface with LLM API providers',
     author='BerriAI',
     packages=[
@@ -11,11 +11,10 @@ setup(
     install_requires=[
         'openai',
         'cohere',
-        'func_timeout',
         'pytest',
         'anthropic',
         'replicate',
         'python-dotenv',
-        'openai[datalib]'
+        'openai[datalib]',
     ],
 )