fix(router.py): speed improvements to the router

This commit is contained in:
Krrish Dholakia 2023-11-27 17:35:02 -08:00
parent 8560794963
commit 04f745e314
4 changed files with 92 additions and 5 deletions

View file

@ -94,6 +94,7 @@ class Router:
# default litellm args
self.default_litellm_params = default_litellm_params
self.default_litellm_params["timeout"] = timeout
self.default_litellm_params["max_retries"] = 0
### HEALTH CHECK THREAD ###
@ -278,8 +279,8 @@ class Router:
If it fails after num_retries, fall back to another model group
"""
model_group = kwargs.get("model")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
try:
response = await self.async_function_with_retries(*args, **kwargs)
self.print_verbose(f'Async Response: {response}')
@ -335,6 +336,8 @@ class Router:
self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
backoff_factor = 1
original_function = kwargs.pop("original_function")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
self.print_verbose(f"async function w/ retries: original_function - {original_function}")
num_retries = kwargs.pop("num_retries")
try:
@ -343,6 +346,11 @@ class Router:
return response
except Exception as e:
original_exception = e
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
or (openai.RateLimitError and fallbacks is not None)):
raise original_exception
### RETRY
for current_attempt in range(num_retries):
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
try:
@ -353,7 +361,7 @@ class Router:
return response
except openai.RateLimitError as e:
if num_retries > 0:
if num_retries > 0 and fallbacks is None:
# on RateLimitError we'll wait for an exponential time before trying again
await asyncio.sleep(backoff_factor)