mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(router.py): enabling retrying with expo backoff (without tenacity) for router
This commit is contained in:
parent
98c45f1b4e
commit
59eaeba92a
9 changed files with 147 additions and 84 deletions
|
@ -2,7 +2,7 @@ from datetime import datetime
|
|||
from typing import Dict, List, Optional, Union, Literal
|
||||
import random, threading, time
|
||||
import litellm, openai
|
||||
import logging
|
||||
import logging, asyncio
|
||||
|
||||
class Router:
|
||||
"""
|
||||
|
@ -23,6 +23,8 @@ class Router:
|
|||
model_names: List = []
|
||||
cache_responses: bool = False
|
||||
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
|
||||
num_retries: int = 0
|
||||
tenacity = None
|
||||
|
||||
def __init__(self,
|
||||
model_list: Optional[list] = None,
|
||||
|
@ -31,7 +33,9 @@ class Router:
|
|||
redis_password: Optional[str] = None,
|
||||
cache_responses: bool = False,
|
||||
num_retries: Optional[int] = None,
|
||||
timeout: float = 600,
|
||||
routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
|
||||
|
||||
if model_list:
|
||||
self.set_model_list(model_list)
|
||||
self.healthy_deployments: List = self.model_list
|
||||
|
@ -39,6 +43,7 @@ class Router:
|
|||
if num_retries:
|
||||
self.num_retries = num_retries
|
||||
|
||||
litellm.request_timeout = timeout
|
||||
self.routing_strategy = routing_strategy
|
||||
### HEALTH CHECK THREAD ###
|
||||
if self.routing_strategy == "least-busy":
|
||||
|
@ -132,6 +137,37 @@ class Router:
|
|||
|
||||
raise ValueError("No models available.")
|
||||
|
||||
def retry_if_rate_limit_error(self, exception):
|
||||
return isinstance(exception, openai.RateLimitError)
|
||||
|
||||
def retry_if_api_error(self, exception):
|
||||
return isinstance(exception, openai.APIError)
|
||||
|
||||
async def async_function_with_retries(self, *args, **kwargs):
|
||||
# we'll backoff exponentially with each retry
|
||||
backoff_factor = 1
|
||||
original_exception = kwargs.pop("original_exception")
|
||||
original_function = kwargs.pop("original_function")
|
||||
for current_attempt in range(self.num_retries):
|
||||
try:
|
||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||
return await original_function(*args, **kwargs)
|
||||
|
||||
except openai.RateLimitError as e:
|
||||
# on RateLimitError we'll wait for an exponential time before trying again
|
||||
await asyncio.sleep(backoff_factor)
|
||||
|
||||
# increase backoff factor for next run
|
||||
backoff_factor *= 2
|
||||
|
||||
except openai.APIError as e:
|
||||
# on APIError we immediately retry without any wait, change this if necessary
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
# for any other exception types, don't retry
|
||||
raise e
|
||||
|
||||
def function_with_retries(self, *args, **kwargs):
|
||||
try:
|
||||
import tenacity
|
||||
|
@ -144,6 +180,9 @@ class Router:
|
|||
retry_info["attempts"] = retry_state.attempt_number
|
||||
retry_info["final_result"] = retry_state.outcome.result()
|
||||
|
||||
if 'model' not in kwargs or 'messages' not in kwargs:
|
||||
raise ValueError("'model' and 'messages' must be included as keyword arguments")
|
||||
|
||||
try:
|
||||
original_exception = kwargs.pop("original_exception")
|
||||
original_function = kwargs.pop("original_function")
|
||||
|
@ -157,7 +196,7 @@ class Router:
|
|||
reraise=True,
|
||||
after=after_callback)
|
||||
|
||||
return retryer(original_function, *args, **kwargs)
|
||||
return retryer(self.acompletion, *args, **kwargs)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error in function_with_retries: {e}\n\nRetry Info: {retry_info}")
|
||||
|
||||
|
@ -180,7 +219,6 @@ class Router:
|
|||
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
|
||||
|
||||
|
||||
|
||||
async def acompletion(self,
|
||||
model: str,
|
||||
messages: List[Dict[str, str]],
|
||||
|
@ -197,7 +235,7 @@ class Router:
|
|||
kwargs["messages"] = messages
|
||||
kwargs["original_exception"] = e
|
||||
kwargs["original_function"] = self.acompletion
|
||||
return self.function_with_retries(**kwargs)
|
||||
return await self.async_function_with_retries(**kwargs)
|
||||
|
||||
def text_completion(self,
|
||||
model: str,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue