fix(router.py): enable retrying with exponential backoff (without tenacity) for router

Krrish Dholakia 2023-11-14 20:57:51 -08:00
parent 98c45f1b4e
commit 59eaeba92a
9 changed files with 147 additions and 84 deletions
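
The core of this change is a hand-rolled retry loop that replaces tenacity on the router's async path: sleep on rate-limit errors, double the wait each attempt, retry API errors immediately, and re-raise everything else. A minimal standalone sketch of the exponential-backoff pattern (function and parameter names here are illustrative, not the router's API):

    import asyncio

    async def retry_with_backoff(coro_fn, *args, num_retries=3, base_delay=1.0, **kwargs):
        # Double the sleep after every failed attempt; re-raise once retries run out.
        delay = base_delay
        for attempt in range(num_retries):
            try:
                return await coro_fn(*args, **kwargs)
            except Exception:
                if attempt == num_retries - 1:
                    raise
                await asyncio.sleep(delay)
                delay *= 2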


@@ -385,6 +385,7 @@ from .exceptions import (
ContextWindowExceededError,
BudgetExceededError,
APIError,
Timeout
)
from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server


@@ -72,11 +72,11 @@ class AzureOpenAIConfig(OpenAIConfig):
top_p)
class AzureChatCompletion(BaseLLM):
_client_session: httpx.Client
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
def __init__(self) -> None:
super().__init__()
self._client_session = self.create_client_session()
def validate_environment(self, api_key, azure_ad_token):
headers = {
@@ -105,6 +105,8 @@ class AzureChatCompletion(BaseLLM):
acompletion: bool = False,
headers: Optional[dict]=None):
super().completion()
if self._client_session is None:
self._client_session = self.create_client_session()
exception_mapping_worked = False
try:
if headers is None:
@@ -157,15 +159,24 @@ class AzureChatCompletion(BaseLLM):
raise e
async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
async with httpx.AsyncClient(timeout=600) as client:
if self._aclient_session is None:
self._aclient_session = self.create_aclient_session()
client = self._aclient_session
try:
response = await client.post(api_base, json=data, headers=headers)
response_json = response.json()
if response.status_code != 200:
raise AzureOpenAIError(status_code=response.status_code, message=response.text)
raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
## RESPONSE OBJECT
return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
except Exception as e:
if isinstance(e, httpx.TimeoutException):
raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
elif response and hasattr(response, "text"):
raise AzureOpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
else:
raise AzureOpenAIError(status_code=500, message=f"{str(e)}")
def streaming(self,
logging_obj,
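
Both LLM handlers in this commit switch from eagerly building httpx sessions in __init__ to creating them on first use. The memoization pattern in isolation, as a hedged sketch (class name and timeout value are illustrative):

    from typing import Optional
    import httpx

    class LazySessionExample:
        _client_session: Optional[httpx.Client] = None

        def _get_session(self) -> httpx.Client:
            # Build the session once, on first request, then reuse it.
            if self._client_session is None:
                self._client_session = httpx.Client(timeout=600)
            return self._client_session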


@@ -153,14 +153,11 @@ class OpenAITextCompletionConfig():
and v is not None}
class OpenAIChatCompletion(BaseLLM):
_client_session: httpx.Client
_aclient_session: httpx.AsyncClient
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
def __init__(self) -> None:
super().__init__()
self._client_session = self.create_client_session()
self._aclient_session = self.create_aclient_session()
self._num_retry_httpx_errors = 3 # httpx throws random errors - e.g. ReadError,
def validate_environment(self, api_key):
headers = {
@@ -193,6 +190,8 @@ class OpenAIChatCompletion(BaseLLM):
logger_fn=None,
headers: Optional[dict]=None):
super().completion()
if self._client_session is None:
self._client_session = self.create_client_session()
exception_mapping_worked = False
try:
if headers is None:
@@ -264,6 +263,8 @@ class OpenAIChatCompletion(BaseLLM):
data: dict, headers: dict,
model_response: ModelResponse):
kwargs = locals()
if self._aclient_session is None:
self._aclient_session = self.create_aclient_session()
client = self._aclient_session
try:
response = await client.post(api_base, json=data, headers=headers)
@@ -273,13 +274,9 @@ class OpenAIChatCompletion(BaseLLM):
## RESPONSE OBJECT
return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
except (httpx.ReadError, httpx.ReadTimeout) as e:
if self._num_retry_httpx_errors > 0:
kwargs["original_function"] = self.acompletion
return self._retry_request(**kwargs)
else:
raise e
except Exception as e:
if isinstance(e, httpx.TimeoutException):
raise OpenAIError(status_code=500, message="Request Timeout Error")
if response and hasattr(response, "text"):
raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
else:


@@ -236,9 +236,6 @@ def mock_completion(model: str, messages: List, stream: Optional[bool] = False,
raise Exception("Mock completion response failed")
@client
@timeout( # type: ignore
600
) ## set timeouts, in case calls hang (e.g. Azure) - default is 600s, override with `force_timeout`
def completion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create


@@ -2,7 +2,7 @@ from datetime import datetime
from typing import Dict, List, Optional, Union, Literal
import random, threading, time
import litellm, openai
import logging
import logging, asyncio
class Router:
"""
@@ -23,6 +23,8 @@ class Router:
model_names: List = []
cache_responses: bool = False
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
num_retries: int = 0
tenacity = None
def __init__(self,
model_list: Optional[list] = None,
@@ -31,7 +33,9 @@
redis_password: Optional[str] = None,
cache_responses: bool = False,
num_retries: Optional[int] = None,
timeout: float = 600,
routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
if model_list:
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list
@@ -39,6 +43,7 @@
if num_retries:
self.num_retries = num_retries
litellm.request_timeout = timeout
self.routing_strategy = routing_strategy
### HEALTH CHECK THREAD ###
if self.routing_strategy == "least-busy":
@@ -132,6 +137,37 @@
raise ValueError("No models available.")
def retry_if_rate_limit_error(self, exception):
return isinstance(exception, openai.RateLimitError)
def retry_if_api_error(self, exception):
return isinstance(exception, openai.APIError)
async def async_function_with_retries(self, *args, **kwargs):
# we'll back off exponentially with each retry
backoff_factor = 1
original_exception = kwargs.pop("original_exception")
original_function = kwargs.pop("original_function")
for current_attempt in range(self.num_retries):
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
return await original_function(*args, **kwargs)
except openai.RateLimitError as e:
# on RateLimitError we'll wait for an exponential time before trying again
await asyncio.sleep(backoff_factor)
# increase backoff factor for next run
backoff_factor *= 2
except openai.APIError as e:
# on APIError we immediately retry without any wait, change this if necessary
pass
except Exception as e:
# for any other exception types, don't retry
raise e
def function_with_retries(self, *args, **kwargs):
try:
import tenacity
@@ -144,6 +180,9 @@
retry_info["attempts"] = retry_state.attempt_number
retry_info["final_result"] = retry_state.outcome.result()
if 'model' not in kwargs or 'messages' not in kwargs:
raise ValueError("'model' and 'messages' must be included as keyword arguments")
try:
original_exception = kwargs.pop("original_exception")
original_function = kwargs.pop("original_function")
@@ -157,7 +196,7 @@
reraise=True,
after=after_callback)
return retryer(original_function, *args, **kwargs)
return retryer(self.acompletion, *args, **kwargs)
except Exception as e:
raise Exception(f"Error in function_with_retries: {e}\n\nRetry Info: {retry_info}")
@@ -180,7 +219,6 @@
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
async def acompletion(self,
model: str,
messages: List[Dict[str, str]],
@@ -197,7 +235,7 @@
kwargs["messages"] = messages
kwargs["original_exception"] = e
kwargs["original_function"] = self.acompletion
return self.function_with_retries(**kwargs)
return await self.async_function_with_retries(**kwargs)
def text_completion(self,
model: str,
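
Taken together, retries and the request timeout are now configured when the Router is constructed; a hedged usage sketch (deployment values mirror the commented-out load-test script further down):

    import asyncio, os
    from litellm import Router

    model_list = [{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    }]

    # num_retries drives the new backoff loop; timeout is stored on litellm.request_timeout
    router = Router(model_list=model_list, num_retries=3, timeout=10)

    async def main():
        response = await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )
        print(response)

    asyncio.run(main())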


@@ -1,63 +1,69 @@
import sys, os
import traceback
from dotenv import load_dotenv
import copy
# import sys, os
# import traceback
# from dotenv import load_dotenv
# import copy
load_dotenv()
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
from litellm import Router
# load_dotenv()
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import asyncio
# from litellm import Router, Timeout
async def call_acompletion(semaphore, router: Router, input_data):
async with semaphore:
# Replace 'input_data' with appropriate parameters for acompletion
response = await router.acompletion(**input_data)
# Handle the response as needed
return response
async def main():
# Initialize the Router
model_list= [{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION")
},
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION")
},
}]
router = Router(model_list=model_list, num_retries=3)
# async def call_acompletion(semaphore, router: Router, input_data):
# async with semaphore:
# try:
# # Use asyncio.wait_for to set a timeout for the task
# response = await router.acompletion(**input_data)
# # Handle the response as needed
# return response
# except Timeout:
# print(f"Task timed out: {input_data}")
# return None # You may choose to return something else or raise an exception
# Create a semaphore with a capacity of 100
semaphore = asyncio.Semaphore(100)
# List to hold all task references
tasks = []
# async def main():
# # Initialize the Router
# model_list= [{
# "model_name": "gpt-3.5-turbo",
# "litellm_params": {
# "model": "gpt-3.5-turbo",
# "api_key": os.getenv("OPENAI_API_KEY"),
# },
# }, {
# "model_name": "gpt-3.5-turbo",
# "litellm_params": {
# "model": "azure/chatgpt-v-2",
# "api_key": os.getenv("AZURE_API_KEY"),
# "api_base": os.getenv("AZURE_API_BASE"),
# "api_version": os.getenv("AZURE_API_VERSION")
# },
# }, {
# "model_name": "gpt-3.5-turbo",
# "litellm_params": {
# "model": "azure/chatgpt-functioncalling",
# "api_key": os.getenv("AZURE_API_KEY"),
# "api_base": os.getenv("AZURE_API_BASE"),
# "api_version": os.getenv("AZURE_API_VERSION")
# },
# }]
# router = Router(model_list=model_list, num_retries=3, timeout=10)
# Launch 1000 tasks
for _ in range(100):
task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
tasks.append(task)
# # Create a semaphore with a capacity of 100
# semaphore = asyncio.Semaphore(100)
# Wait for all tasks to complete
responses = await asyncio.gather(*tasks)
# Process responses as needed
print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
# Run the main function
asyncio.run(main())
# # List to hold all task references
# tasks = []
# # Launch 1000 tasks
# for _ in range(1000):
# task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
# tasks.append(task)
# # Wait for all tasks to complete
# responses = await asyncio.gather(*tasks)
# # Process responses as needed
# print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
# # Run the main function
# asyncio.run(main())
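
The commented-out script above caps fan-out with an asyncio.Semaphore before gathering 1000 tasks. That concurrency-limiting pattern in isolation (the sleeping worker here is a stand-in for router.acompletion):

    import asyncio

    async def call_with_limit(semaphore: asyncio.Semaphore, task_id: int):
        async with semaphore:  # at most 100 calls run concurrently
            await asyncio.sleep(0.01)  # stand-in for an API call
            return task_id

    async def main():
        semaphore = asyncio.Semaphore(100)
        tasks = [asyncio.create_task(call_with_limit(semaphore, i)) for i in range(1000)]
        responses = await asyncio.gather(*tasks)
        print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")

    asyncio.run(main())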


@@ -246,7 +246,7 @@ def test_acompletion_on_router():
]
async def get_response():
router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True)
router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True, timeout=10)
response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
print(f"response1: {response1}")
response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)


@@ -3236,6 +3236,15 @@ def exception_type(
exception_type = type(original_exception).__name__
else:
exception_type = ""
if "Request Timeout Error" in error_str:
exception_mapping_worked = True
raise Timeout(
message=f"APITimeoutError - Request timed out",
model=model,
llm_provider=custom_llm_provider
)
if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai":
if "This model's maximum context length is" in error_str or "Request too large" in error_str:
exception_mapping_worked = True
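
Because Timeout is now exported from the package (first hunk) and raised here whenever "Request Timeout Error" appears in the error string, callers can catch timeouts uniformly across providers; a small hedged example:

    import litellm
    from litellm import Timeout

    try:
        litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello"}],
        )
    except Timeout:
        # Raised for provider-agnostic request timeouts after this mapping.
        print("request timed out")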
@@ -4081,8 +4090,12 @@ def exception_type(
response=original_exception.response
)
else:
# raise BadRequestError(message=str(original_exception), llm_provider=custom_llm_provider, model=model, response=original_exception.response)
raise original_exception
exception_mapping_worked = True
raise APIConnectionError(
message=f"{str(original_exception)}",
llm_provider=custom_llm_provider,
model=model
)
except Exception as e:
# LOGGING
exception_logging(


@@ -8,7 +8,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.8"
openai = "*"
openai = ">=1.0.0"
python-dotenv = ">=0.2.0"
tiktoken = ">=0.4.0"
importlib-metadata = ">=6.8.0"