fix(main.py): keep client consistent across calls + exponential backoff retry on ratelimit errors

This commit is contained in:
Krrish Dholakia 2023-11-14 16:25:36 -08:00
parent 5963d9d283
commit a7222f257c
9 changed files with 239 additions and 131 deletions

View file

@ -19,6 +19,7 @@ telemetry = True
max_tokens = 256 # OpenAI Defaults max_tokens = 256 # OpenAI Defaults
drop_params = False drop_params = False
retry = True retry = True
request_timeout: float = 600
api_key: Optional[str] = None api_key: Optional[str] = None
openai_key: Optional[str] = None openai_key: Optional[str] = None
azure_key: Optional[str] = None azure_key: Optional[str] = None
@ -46,6 +47,7 @@ _current_cost = 0 # private variable, used if max budget is set
error_logs: Dict = {} error_logs: Dict = {}
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
client_session: Optional[httpx.Client] = None client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None model_fallbacks: Optional[List] = None
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
num_retries: Optional[int] = None num_retries: Optional[int] = None

View file

@ -105,7 +105,7 @@ class APIError(APIError): # type: ignore
super().__init__( super().__init__(
self.message, self.message,
request=request, # type: ignore request=request, # type: ignore
body=None, body=None
) )
# raised if an invalid request (not get, delete, put, post) is made # raised if an invalid request (not get, delete, put, post) is made

View file

@ -9,13 +9,25 @@ class BaseLLM:
if litellm.client_session: if litellm.client_session:
_client_session = litellm.client_session _client_session = litellm.client_session
else: else:
_client_session = httpx.Client(timeout=600) _client_session = httpx.Client(timeout=litellm.request_timeout)
return _client_session return _client_session
def create_aclient_session(self):
if litellm.aclient_session:
_aclient_session = litellm.aclient_session
else:
_aclient_session = httpx.AsyncClient(timeout=litellm.request_timeout)
return _aclient_session
def __exit__(self): def __exit__(self):
if hasattr(self, '_client_session'): if hasattr(self, '_client_session'):
self._client_session.close() self._client_session.close()
async def __aexit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, '_aclient_session'):
await self._aclient_session.aclose()
def validate_environment(self): # set up the environment required to run the model def validate_environment(self): # set up the environment required to run the model
pass pass

View file

@ -154,10 +154,12 @@ class OpenAITextCompletionConfig():
class OpenAIChatCompletion(BaseLLM): class OpenAIChatCompletion(BaseLLM):
_client_session: httpx.Client _client_session: httpx.Client
_aclient_session: httpx.AsyncClient
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._client_session = self.create_client_session() self._client_session = self.create_client_session()
self._aclient_session = self.create_aclient_session()
def validate_environment(self, api_key): def validate_environment(self, api_key):
headers = { headers = {
@ -251,15 +253,15 @@ class OpenAIChatCompletion(BaseLLM):
api_base: str, api_base: str,
data: dict, headers: dict, data: dict, headers: dict,
model_response: ModelResponse): model_response: ModelResponse):
async with httpx.AsyncClient(timeout=600) as client: client = self._aclient_session
response = await client.post(api_base, json=data, headers=headers)
response_json = response.json()
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
## RESPONSE OBJECT response = await client.post(api_base, json=data, headers=headers)
return convert_to_model_response_object(response_object=response_json, model_response_object=model_response) response_json = response.json()
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
## RESPONSE OBJECT
return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
def streaming(self, def streaming(self,
logging_obj, logging_obj,
@ -290,8 +292,7 @@ class OpenAIChatCompletion(BaseLLM):
headers: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse,
model: str): model: str):
client = httpx.AsyncClient() async with self._aclient_session.stream(
async with client.stream(
url=f"{api_base}", url=f"{api_base}",
json=data, json=data,
headers=headers, headers=headers,

View file

@ -6,11 +6,14 @@ import time
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
import litellm import litellm
import httpx
class VertexAIError(Exception): class VertexAIError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs

View file

@ -108,6 +108,7 @@ class Completions():
response = completion(model=model, messages=messages, **self.params) response = completion(model=model, messages=messages, **self.params)
return response return response
@client
async def acompletion(*args, **kwargs): async def acompletion(*args, **kwargs):
""" """
Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@ -149,63 +150,51 @@ async def acompletion(*args, **kwargs):
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"] model = args[0] if len(args) > 0 else kwargs["model"]
messages = args[1] if len(args) > 1 else kwargs["messages"]
### INITIALIZE LOGGING OBJECT ###
kwargs["litellm_call_id"] = str(uuid.uuid4())
start_time = datetime.datetime.now()
logging_obj = Logging(model=model, messages=messages, stream=kwargs.get("stream", False), litellm_call_id=kwargs["litellm_call_id"], function_id=kwargs.get("id", None), call_type="completion", start_time=start_time)
### PASS ARGS TO COMPLETION ### ### PASS ARGS TO COMPLETION ###
kwargs["litellm_logging_obj"] = logging_obj
kwargs["acompletion"] = True kwargs["acompletion"] = True
kwargs["model"] = model try:
kwargs["messages"] = messages # Use a partial function to pass your keyword arguments
# Use a partial function to pass your keyword arguments func = partial(completion, *args, **kwargs)
func = partial(completion, *args, **kwargs)
# Add the context to the function # Add the context to the function
ctx = contextvars.copy_context() ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func) func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
if (custom_llm_provider == "openai" if (custom_llm_provider == "openai"
or custom_llm_provider == "azure" or custom_llm_provider == "azure"
or custom_llm_provider == "custom_openai" or custom_llm_provider == "custom_openai"
or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all. or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
if kwargs.get("stream", False): if kwargs.get("stream", False):
response = completion(*args, **kwargs) response = completion(*args, **kwargs)
else:
# Await normally
init_response = completion(*args, **kwargs)
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse):
response = init_response
else: else:
response = await init_response # Await normally
else: init_response = completion(*args, **kwargs)
# Call the synchronous function using run_in_executor if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
response = await loop.run_in_executor(None, func_with_context) response = init_response
if kwargs.get("stream", False): # return an async generator else:
# do not change this response = await init_response
# for stream = True, always return an async generator else:
# See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193 # Call the synchronous function using run_in_executor
# return response response = await loop.run_in_executor(None, func_with_context)
return( if kwargs.get("stream", False): # return an async generator
line # do not change this
async for line in response # for stream = True, always return an async generator
) # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
else: # return response
end_time = datetime.datetime.now() return(
# [OPTIONAL] ADD TO CACHE line
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object async for line in response
litellm.cache.add_cache(response, *args, **kwargs) )
else:
return response
except Exception as e:
## Map to OpenAI Exception
raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
)
# LOG SUCCESS
logging_obj.success_handler(response, start_time, end_time)
# RETURN RESULT
response._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
return response
def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs): def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
""" """
@ -1420,8 +1409,14 @@ def completion_with_retries(*args, **kwargs):
raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}") raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
num_retries = kwargs.pop("num_retries", 3) num_retries = kwargs.pop("num_retries", 3)
retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True) retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
return retryer(completion, *args, **kwargs) original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry":
retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
elif retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
return retryer(original_function, *args, **kwargs)
def batch_completion( def batch_completion(

View file

@ -0,0 +1,47 @@
import sys, os
import traceback
from dotenv import load_dotenv
import copy
load_dotenv()
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
from litellm import Router
async def call_acompletion(semaphore, router: Router, input_data):
async with semaphore:
# Replace 'input_data' with appropriate parameters for acompletion
response = await router.acompletion(**input_data)
# Handle the response as needed
return response
async def main():
# Initialize the Router
model_list= [{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
}]
router = Router(model_list=model_list, num_retries=3)
# Create a semaphore with a capacity of 100
semaphore = asyncio.Semaphore(100)
# List to hold all task references
tasks = []
# Launch 1000 tasks
for _ in range(1000):
task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
tasks.append(task)
# Wait for all tasks to complete
responses = await asyncio.gather(*tasks)
# Process responses as needed
# Run the main function
asyncio.run(main())

View file

@ -317,7 +317,7 @@ def test_function_calling_on_router():
except Exception as e: except Exception as e:
print(f"An exception occurred: {e}") print(f"An exception occurred: {e}")
test_function_calling_on_router() # test_function_calling_on_router()
def test_aembedding_on_router(): def test_aembedding_on_router():
try: try:

View file

@ -456,6 +456,7 @@ from enum import Enum
class CallTypes(Enum): class CallTypes(Enum):
embedding = 'embedding' embedding = 'embedding'
completion = 'completion' completion = 'completion'
acompletion = 'acompletion'
# Logging function -> log the exact model details + what's being sent | Non-Blocking # Logging function -> log the exact model details + what's being sent | Non-Blocking
class Logging: class Logging:
@ -984,7 +985,7 @@ def exception_logging(
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function): def client(original_function):
global liteDebuggerClient, get_all_keys global liteDebuggerClient, get_all_keys
import inspect
def function_setup( def function_setup(
start_time, *args, **kwargs start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@ -1036,7 +1037,7 @@ def client(original_function):
# INIT LOGGER - for user-specified integrations # INIT LOGGER - for user-specified integrations
model = args[0] if len(args) > 0 else kwargs["model"] model = args[0] if len(args) > 0 else kwargs["model"]
call_type = original_function.__name__ call_type = original_function.__name__
if call_type == CallTypes.completion.value: if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
if len(args) > 1: if len(args) > 1:
messages = args[1] messages = args[1]
elif kwargs.get("messages", None): elif kwargs.get("messages", None):
@ -1183,7 +1184,107 @@ def client(original_function):
): # make it easy to get to the debugger logs if you've initialized it ): # make it easy to get to the debugger logs if you've initialized it
e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}" e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
raise e raise e
return wrapper
async def wrapper_async(*args, **kwargs):
start_time = datetime.datetime.now()
result = None
logging_obj = kwargs.get("litellm_logging_obj", None)
# only set litellm_call_id if its not in kwargs
if "litellm_call_id" not in kwargs:
kwargs["litellm_call_id"] = str(uuid.uuid4())
try:
model = args[0] if len(args) > 0 else kwargs["model"]
except:
raise ValueError("model param not passed in.")
try:
if logging_obj is None:
logging_obj = function_setup(start_time, *args, **kwargs)
kwargs["litellm_logging_obj"] = logging_obj
# [OPTIONAL] CHECK BUDGET
if litellm.max_budget:
if litellm._current_cost > litellm.max_budget:
raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget)
# [OPTIONAL] CHECK CACHE
print_verbose(f"litellm.cache: {litellm.cache}")
print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}")
# if caching is false, don't run this
if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function
# checking cache
if (litellm.cache != None):
print_verbose(f"Checking Cache")
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result != None:
print_verbose(f"Cache Hit!")
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
else:
return cached_result
# MODEL CALL
result = original_function(*args, **kwargs)
end_time = datetime.datetime.now()
if "stream" in kwargs and kwargs["stream"] == True:
if "complete_response" in kwargs and kwargs["complete_response"] == True:
chunks = []
for idx, chunk in enumerate(result):
chunks.append(chunk)
return litellm.stream_chunk_builder(chunks)
else:
return result
result = await result
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
logging_obj.success_handler(result, start_time, end_time)
# RETURN RESULT
return result
except Exception as e:
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value:
num_retries = (
kwargs.get("num_retries", None)
or litellm.num_retries
or None
)
litellm.num_retries = None # set retries to None to prevent infinite loops
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {})
if num_retries:
kwargs["num_retries"] = num_retries
kwargs["original_function"] = original_function
if (isinstance(e, openai.RateLimitError)): # rate limiting specific error
kwargs["retry_strategy"] = "exponential_backoff_retry"
elif (isinstance(e, openai.APIError)): # generic api error
kwargs["retry_strategy"] = "constant_retry"
return litellm.completion_with_retries(*args, **kwargs)
elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict:
if len(args) > 0:
args[0] = context_window_fallback_dict[model]
else:
kwargs["model"] = context_window_fallback_dict[model]
return original_function(*args, **kwargs)
traceback_exception = traceback.format_exc()
crash_reporting(*args, **kwargs, exception=traceback_exception)
end_time = datetime.datetime.now()
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
if logging_obj:
threading.Thread(target=logging_obj.failure_handler, args=(e, traceback_exception, start_time, end_time)).start()
raise e
# Use httpx to determine if the original function is a coroutine
is_coroutine = inspect.iscoroutinefunction(original_function)
# Return the appropriate wrapper based on the original function type
if is_coroutine:
return wrapper_async
else:
return wrapper
####### USAGE CALCULATOR ################ ####### USAGE CALCULATOR ################
@ -3116,31 +3217,13 @@ def exception_type(
print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.") # noqa print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.") # noqa
print() # noqa print() # noqa
try: try:
if isinstance(original_exception, OriginalError): if model:
# Handle the OpenAIError
exception_mapping_worked = True
if custom_llm_provider == "openrouter":
if original_exception.http_status == 413:
raise BadRequestError(
message=str(original_exception),
model=model,
llm_provider="openrouter"
)
original_exception.llm_provider = "openrouter"
if "This model's maximum context length is" in original_exception._message:
raise ContextWindowExceededError(
message=str(original_exception),
model=model,
llm_provider=original_exception.llm_provider
)
raise original_exception
elif model:
error_str = str(original_exception) error_str = str(original_exception)
if isinstance(original_exception, BaseException): if isinstance(original_exception, BaseException):
exception_type = type(original_exception).__name__ exception_type = type(original_exception).__name__
else: else:
exception_type = "" exception_type = ""
if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai": if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai":
if "This model's maximum context length is" in error_str or "Request too large" in error_str: if "This model's maximum context length is" in error_str or "Request too large" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise ContextWindowExceededError( raise ContextWindowExceededError(
@ -3191,6 +3274,14 @@ def exception_type(
llm_provider="openai", llm_provider="openai",
response=original_exception.response response=original_exception.response
) )
elif original_exception.status_code == 503:
exception_mapping_worked = True
raise ServiceUnavailableError(
message=f"OpenAIException - {original_exception.message}",
model=model,
llm_provider="openai",
response=original_exception.response
)
elif original_exception.status_code == 504: # gateway timeout error elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True exception_mapping_worked = True
raise Timeout( raise Timeout(
@ -3968,49 +4059,6 @@ def exception_type(
model=model, model=model,
request=original_exception.request request=original_exception.request
) )
elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
if hasattr(original_exception, "status_code"):
exception_mapping_worked = True
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"CustomOpenAIException - {original_exception.message}",
llm_provider="custom_openai",
model=model
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
raise Timeout(
message=f"CustomOpenAIException - {original_exception.message}",
model=model,
llm_provider="custom_openai",
request=original_exception.request
)
if original_exception.status_code == 422:
exception_mapping_worked = True
raise BadRequestError(
message=f"CustomOpenAIException - {original_exception.message}",
model=model,
llm_provider="custom_openai",
response=original_exception.response
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=f"CustomOpenAIException - {original_exception.message}",
model=model,
llm_provider="custom_openai",
response=original_exception.response
)
else:
exception_mapping_worked = True
raise APIError(
status_code=original_exception.status_code,
message=f"CustomOpenAIException - {original_exception.message}",
llm_provider="custom_openai",
model=model,
request=original_exception.request
)
if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk
exception_mapping_worked = True exception_mapping_worked = True
raise BadRequestError( raise BadRequestError(