fix(router.py): enabling retrying with expo backoff (without tenacity) for router

Krrish Dholakia 2023-11-14 20:57:51 -08:00
parent 98c45f1b4e
commit 59eaeba92a
9 changed files with 147 additions and 84 deletions

View file

@@ -385,6 +385,7 @@ from .exceptions import (
     ContextWindowExceededError,
     BudgetExceededError,
     APIError,
+    Timeout
 )
 from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server
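
Why this export matters: re-exporting Timeout from litellm's top level lets callers catch request timeouts as a typed litellm exception instead of matching on httpx internals. A minimal usage sketch (hypothetical caller code, not part of this commit; assumes OPENAI_API_KEY is set):

import litellm
from litellm import Timeout

try:
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
    )
except Timeout:
    # raised when exception_type() maps "Request Timeout Error" (see the utils.py hunk below)
    print("request timed out; safe to retry")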

View file

@@ -72,11 +72,11 @@ class AzureOpenAIConfig(OpenAIConfig):
               top_p)
 class AzureChatCompletion(BaseLLM):
-    _client_session: httpx.Client
+    _client_session: Optional[httpx.Client] = None
+    _aclient_session: Optional[httpx.AsyncClient] = None
     def __init__(self) -> None:
         super().__init__()
-        self._client_session = self.create_client_session()
     def validate_environment(self, api_key, azure_ad_token):
         headers = {
@@ -105,6 +105,8 @@ class AzureChatCompletion(BaseLLM):
                  acompletion: bool = False,
                  headers: Optional[dict]=None):
         super().completion()
+        if self._client_session is None:
+            self._client_session = self.create_client_session()
         exception_mapping_worked = False
         try:
             if headers is None:
@@ -157,15 +159,24 @@ class AzureChatCompletion(BaseLLM):
             raise e
     async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
-        async with httpx.AsyncClient(timeout=600) as client:
+        if self._aclient_session is None:
+            self._aclient_session = self.create_aclient_session()
+        client = self._aclient_session
+        try:
             response = await client.post(api_base, json=data, headers=headers)
             response_json = response.json()
             if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text)
+                raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
             ## RESPONSE OBJECT
             return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+        except Exception as e:
+            if isinstance(e, httpx.TimeoutException):  # a bare `if httpx.TimeoutException:` is always truthy
+                raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
+            elif response and hasattr(response, "text"):
+                raise AzureOpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
+            else:
+                raise AzureOpenAIError(status_code=500, message=f"{str(e)}")
     def streaming(self,
                   logging_obj,
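
The change above (mirrored in the OpenAI handler below) swaps a per-request `async with httpx.AsyncClient(...)` for a session that is created on first use and then reused, so connections are pooled across calls. A minimal standalone sketch of that lazy-session pattern; the class and method names here are illustrative, not litellm's actual BaseLLM API:

from typing import Optional
import httpx

class LazyClientMixin:
    _aclient_session: Optional[httpx.AsyncClient] = None

    def create_aclient_session(self) -> httpx.AsyncClient:
        # one shared async client -> pooled connections, fewer TLS handshakes
        return httpx.AsyncClient(timeout=600)

    async def post_json(self, url: str, data: dict, headers: dict) -> httpx.Response:
        if self._aclient_session is None:  # create lazily, on the running event loop
            self._aclient_session = self.create_aclient_session()
        return await self._aclient_session.post(url, json=data, headers=headers)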

View file

@@ -153,14 +153,11 @@ class OpenAITextCompletionConfig():
                 and v is not None}
 class OpenAIChatCompletion(BaseLLM):
-    _client_session: httpx.Client
-    _aclient_session: httpx.AsyncClient
+    _client_session: Optional[httpx.Client] = None
+    _aclient_session: Optional[httpx.AsyncClient] = None
     def __init__(self) -> None:
         super().__init__()
-        self._client_session = self.create_client_session()
-        self._aclient_session = self.create_aclient_session()
-        self._num_retry_httpx_errors = 3 # httpx throws random errors - e.g. ReadError,
     def validate_environment(self, api_key):
         headers = {
@@ -193,6 +190,8 @@ class OpenAIChatCompletion(BaseLLM):
                  logger_fn=None,
                  headers: Optional[dict]=None):
         super().completion()
+        if self._client_session is None:
+            self._client_session = self.create_client_session()
         exception_mapping_worked = False
         try:
             if headers is None:
@@ -264,6 +263,8 @@ class OpenAIChatCompletion(BaseLLM):
                           data: dict, headers: dict,
                           model_response: ModelResponse):
         kwargs = locals()
+        if self._aclient_session is None:
+            self._aclient_session = self.create_aclient_session()
         client = self._aclient_session
         try:
             response = await client.post(api_base, json=data, headers=headers)
@@ -273,13 +274,9 @@ class OpenAIChatCompletion(BaseLLM):
             ## RESPONSE OBJECT
             return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
-        except httpx.ReadError or httpx.ReadTimeout:
-            if self._num_retry_httpx_errors > 0:
-                kwargs["original_function"] = self.acompletion
-                return self._retry_request(**kwargs)
-            else:
-                raise e
         except Exception as e:
+            if isinstance(e, httpx.TimeoutException):  # a bare `if httpx.TimeoutException:` is always truthy
+                raise OpenAIError(status_code=500, message="Request Timeout Error")
             if response and hasattr(response, "text"):
                 raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
             else:
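
One detail worth flagging in the removed lines: `except httpx.ReadError or httpx.ReadTimeout:` never caught ReadTimeout. The `or` expression is evaluated first and yields its first truthy operand (httpx.ReadError), so only that class was handled; catching both requires a tuple. A small self-contained demonstration of the pitfall:

class AError(Exception): pass
class BError(Exception): pass

def is_caught(exc):
    try:
        raise exc
    except AError or BError:  # `or` evaluates to AError, so only AError is handled
        return True
    except Exception:
        return False

print(is_caught(AError()))  # True
print(is_caught(BError()))  # False; the correct form is `except (AError, BError):`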

View file

@@ -236,9 +236,6 @@ def mock_completion(model: str, messages: List, stream: Optional[bool] = False,
         raise Exception("Mock completion response failed")
 @client
-@timeout(  # type: ignore
-    600
-)  ## set timeouts, in case calls hang (e.g. Azure) - default is 600s, override with `force_timeout`
 def completion(
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create

View file

@@ -2,7 +2,7 @@ from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
 import litellm, openai
-import logging
+import logging, asyncio
 class Router:
     """
@@ -23,6 +23,8 @@ class Router:
     model_names: List = []
     cache_responses: bool = False
     default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
+    num_retries: int = 0
+    tenacity = None
     def __init__(self,
                  model_list: Optional[list] = None,
@@ -31,7 +33,9 @@ class Router:
                  redis_password: Optional[str] = None,
                  cache_responses: bool = False,
                  num_retries: Optional[int] = None,
+                 timeout: float = 600,
                  routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
         if model_list:
             self.set_model_list(model_list)
             self.healthy_deployments: List = self.model_list
@@ -39,6 +43,7 @@ class Router:
         if num_retries:
             self.num_retries = num_retries
+        litellm.request_timeout = timeout
         self.routing_strategy = routing_strategy
         ### HEALTH CHECK THREAD ###
         if self.routing_strategy == "least-busy":
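
Taken together, the constructor changes mean retry count and a global request timeout are now configured at Router creation: `num_retries` drives the new backoff loop and `timeout` is written to `litellm.request_timeout`. A minimal usage sketch, mirroring the commented-out test script later in this commit (keys and model names are placeholders):

import os
from litellm import Router

model_list = [{
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {
        "model": "gpt-3.5-turbo",
        "api_key": os.getenv("OPENAI_API_KEY"),
    },
}]

# retry up to 3 times with exponential backoff; time requests out after 10s
router = Router(model_list=model_list, num_retries=3, timeout=10)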
@@ -132,6 +137,37 @@ class Router:
             raise ValueError("No models available.")
+    def retry_if_rate_limit_error(self, exception):
+        return isinstance(exception, openai.RateLimitError)
+
+    def retry_if_api_error(self, exception):
+        return isinstance(exception, openai.APIError)
+
+    async def async_function_with_retries(self, *args, **kwargs):
+        # we'll backoff exponentially with each retry
+        backoff_factor = 1
+        original_exception = kwargs.pop("original_exception")
+        original_function = kwargs.pop("original_function")
+        for current_attempt in range(self.num_retries):
+            try:
+                # if the function call is successful, no exception will be raised and we'll break out of the loop
+                return await original_function(*args, **kwargs)
+            except openai.RateLimitError as e:
+                # on RateLimitError we'll wait for an exponential time before trying again
+                await asyncio.sleep(backoff_factor)
+                # increase backoff factor for next run
+                backoff_factor *= 2
+            except openai.APIError as e:
+                # on APIError we immediately retry without any wait, change this if necessary
+                pass
+            except Exception as e:
+                # for any other exception types, don't retry
+                raise e
     def function_with_retries(self, *args, **kwargs):
         try:
             import tenacity
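
The loop above is the heart of the "without tenacity" change: plain asyncio.sleep with a doubling delay. Note that, as committed, the loop falls through and implicitly returns None once num_retries is exhausted. A self-contained sketch of the same pattern with an explicit final re-raise; this is a generic helper of my own, not litellm's API:

import asyncio

async def with_backoff(func, *args, num_retries=3, base_delay=1.0, retry_on=(Exception,), **kwargs):
    """Retry an async callable with exponential backoff; re-raise the last error."""
    delay = base_delay
    for attempt in range(num_retries + 1):  # initial call + num_retries retries
        try:
            return await func(*args, **kwargs)
        except retry_on:
            if attempt == num_retries:
                raise  # out of retries: surface the final failure instead of returning None
            await asyncio.sleep(delay)  # wait base_delay, then 2x, 4x, ...
            delay *= 2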
@@ -144,6 +180,9 @@ class Router:
                 retry_info["attempts"] = retry_state.attempt_number
                 retry_info["final_result"] = retry_state.outcome.result()
+        if 'model' not in kwargs or 'messages' not in kwargs:
+            raise ValueError("'model' and 'messages' must be included as keyword arguments")
         try:
             original_exception = kwargs.pop("original_exception")
             original_function = kwargs.pop("original_function")
@@ -157,7 +196,7 @@ class Router:
                               reraise=True,
                               after=after_callback)
-            return retryer(original_function, *args, **kwargs)
+            return retryer(self.acompletion, *args, **kwargs)
         except Exception as e:
             raise Exception(f"Error in function_with_retries: {e}\n\nRetry Info: {retry_info}")
@@ -180,7 +219,6 @@ class Router:
             return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
-
     async def acompletion(self,
                           model: str,
                           messages: List[Dict[str, str]],
@@ -197,7 +235,7 @@ class Router:
             kwargs["messages"] = messages
             kwargs["original_exception"] = e
             kwargs["original_function"] = self.acompletion
-            return self.function_with_retries(**kwargs)
+            return await self.async_function_with_retries(**kwargs)
     def text_completion(self,
                         model: str,

View file

@@ -1,63 +1,69 @@
-import sys, os
-import traceback
-from dotenv import load_dotenv
-import copy
-load_dotenv()
-sys.path.insert(
-    0, os.path.abspath("../..")
-) # Adds the parent directory to the system path
-import asyncio
-from litellm import Router
-
-async def call_acompletion(semaphore, router: Router, input_data):
-    async with semaphore:
-        # Replace 'input_data' with appropriate parameters for acompletion
-        response = await router.acompletion(**input_data)
-        # Handle the response as needed
-        return response
-
-async def main():
-    # Initialize the Router
-    model_list= [{
-        "model_name": "gpt-3.5-turbo",
-        "litellm_params": {
-            "model": "gpt-3.5-turbo",
-            "api_key": os.getenv("OPENAI_API_KEY"),
-        },
-    }, {
-        "model_name": "gpt-3.5-turbo",
-        "litellm_params": {
-            "model": "azure/chatgpt-v-2",
-            "api_key": os.getenv("AZURE_API_KEY"),
-            "api_base": os.getenv("AZURE_API_BASE"),
-            "api_version": os.getenv("AZURE_API_VERSION")
-        },
-    }, {
-        "model_name": "gpt-3.5-turbo",
-        "litellm_params": {
-            "model": "azure/chatgpt-functioncalling",
-            "api_key": os.getenv("AZURE_API_KEY"),
-            "api_base": os.getenv("AZURE_API_BASE"),
-            "api_version": os.getenv("AZURE_API_VERSION")
-        },
-    }]
-    router = Router(model_list=model_list, num_retries=3)
-
-    # Create a semaphore with a capacity of 100
-    semaphore = asyncio.Semaphore(100)
-
-    # List to hold all task references
-    tasks = []
-
-    # Launch 1000 tasks
-    for _ in range(100):
-        task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
-        tasks.append(task)
-
-    # Wait for all tasks to complete
-    responses = await asyncio.gather(*tasks)
-    # Process responses as needed
-    print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
-
-# Run the main function
-asyncio.run(main())
+# import sys, os
+# import traceback
+# from dotenv import load_dotenv
+# import copy
+# load_dotenv()
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# ) # Adds the parent directory to the system path
+# import asyncio
+# from litellm import Router, Timeout
+
+# async def call_acompletion(semaphore, router: Router, input_data):
+#     async with semaphore:
+#         try:
+#             # Use asyncio.wait_for to set a timeout for the task
+#             response = await router.acompletion(**input_data)
+#             # Handle the response as needed
+#             return response
+#         except Timeout:
+#             print(f"Task timed out: {input_data}")
+#             return None # You may choose to return something else or raise an exception
+
+# async def main():
+#     # Initialize the Router
+#     model_list= [{
+#         "model_name": "gpt-3.5-turbo",
+#         "litellm_params": {
+#             "model": "gpt-3.5-turbo",
+#             "api_key": os.getenv("OPENAI_API_KEY"),
+#         },
+#     }, {
+#         "model_name": "gpt-3.5-turbo",
+#         "litellm_params": {
+#             "model": "azure/chatgpt-v-2",
+#             "api_key": os.getenv("AZURE_API_KEY"),
+#             "api_base": os.getenv("AZURE_API_BASE"),
+#             "api_version": os.getenv("AZURE_API_VERSION")
+#         },
+#     }, {
+#         "model_name": "gpt-3.5-turbo",
+#         "litellm_params": {
+#             "model": "azure/chatgpt-functioncalling",
+#             "api_key": os.getenv("AZURE_API_KEY"),
+#             "api_base": os.getenv("AZURE_API_BASE"),
+#             "api_version": os.getenv("AZURE_API_VERSION")
+#         },
+#     }]
+#     router = Router(model_list=model_list, num_retries=3, timeout=10)
+
+#     # Create a semaphore with a capacity of 100
+#     semaphore = asyncio.Semaphore(100)
+
+#     # List to hold all task references
+#     tasks = []
+
+#     # Launch 1000 tasks
+#     for _ in range(1000):
+#         task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
+#         tasks.append(task)
+
+#     # Wait for all tasks to complete
+#     responses = await asyncio.gather(*tasks)
+#     # Process responses as needed
+#     print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
+
+# # Run the main function
+# asyncio.run(main())

View file

@@ -246,7 +246,7 @@ def test_acompletion_on_router():
     ]
     async def get_response():
-        router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True)
+        router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True, timeout=10)
         response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
         print(f"response1: {response1}")
         response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)

View file

@@ -3236,6 +3236,15 @@ def exception_type(
             exception_type = type(original_exception).__name__
         else:
             exception_type = ""
+        if "Request Timeout Error" in error_str:
+            exception_mapping_worked = True
+            raise Timeout(
+                message=f"APITimeoutError - Request timed out",
+                model=model,
+                llm_provider=custom_llm_provider
+            )
         if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai":
             if "This model's maximum context length is" in error_str or "Request too large" in error_str:
                 exception_mapping_worked = True
@@ -4081,8 +4090,12 @@ def exception_type(
                         response=original_exception.response
                     )
                 else:
-                    # raise BadRequestError(message=str(original_exception), llm_provider=custom_llm_provider, model=model, response=original_exception.response)
-                    raise original_exception
+                    exception_mapping_worked = True
+                    raise APIConnectionError(
+                        message=f"{str(original_exception)}",
+                        llm_provider=custom_llm_provider,
+                        model=model
+                    )
     except Exception as e:
         # LOGGING
         exception_logging(
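
This is how the handler-level "Request Timeout Error" message (raised in the Azure/OpenAI hunks above) becomes a typed Timeout that callers and the router's retry loop can act on. A minimal sketch of the same string-matching mapper pattern; this is a simplified stand-in, not the real exception_type signature:

class Timeout(Exception):  # stand-in for litellm's Timeout exception
    def __init__(self, message, model=None, llm_provider=None):
        super().__init__(message)
        self.model = model
        self.llm_provider = llm_provider

def map_exception(original_exception, model, custom_llm_provider):
    error_str = str(original_exception)
    if "Request Timeout Error" in error_str:
        # normalize the raw handler error into one typed exception
        raise Timeout(
            message="APITimeoutError - Request timed out",
            model=model,
            llm_provider=custom_llm_provider,
        )
    raise original_exception  # everything else falls through unchanged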

View file

@ -8,7 +8,7 @@ readme = "README.md"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.8" python = "^3.8"
openai = "*" openai = ">=1.0.0"
python-dotenv = ">=0.2.0" python-dotenv = ">=0.2.0"
tiktoken = ">=0.4.0" tiktoken = ">=0.4.0"
importlib-metadata = ">=6.8.0" importlib-metadata = ">=6.8.0"