fix(router.py): enabling retrying with expo backoff (without tenacity) for router
parent 98c45f1b4e
commit 59eaeba92a
9 changed files with 147 additions and 84 deletions
@@ -385,6 +385,7 @@ from .exceptions import (
    ContextWindowExceededError,
    BudgetExceededError,
    APIError,
    Timeout
)
from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server
@@ -72,11 +72,11 @@ class AzureOpenAIConfig(OpenAIConfig):
            top_p)

class AzureChatCompletion(BaseLLM):
    _client_session: httpx.Client
    _client_session: Optional[httpx.Client] = None
    _aclient_session: Optional[httpx.AsyncClient] = None

    def __init__(self) -> None:
        super().__init__()
        self._client_session = self.create_client_session()

    def validate_environment(self, api_key, azure_ad_token):
        headers = {
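The change above turns the Azure client sessions into Optional class attributes that default to None, so a session is only built when a request actually needs it. A minimal standalone sketch of that lazy-initialization pattern, assuming plain httpx clients stand in for litellm's create_client_session helpers:

from typing import Optional
import httpx

class LazySessionExample:
    # sessions start as None and are created on first use
    _client_session: Optional[httpx.Client] = None
    _aclient_session: Optional[httpx.AsyncClient] = None

    def create_client_session(self) -> httpx.Client:
        # stand-in for the real helper: a sync client with a long timeout
        return httpx.Client(timeout=600)

    def create_aclient_session(self) -> httpx.AsyncClient:
        # stand-in for the real helper: the async counterpart
        return httpx.AsyncClient(timeout=600)

    def get_sync_session(self) -> httpx.Client:
        if self._client_session is None:
            self._client_session = self.create_client_session()
        return self._client_session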
@@ -105,6 +105,8 @@ class AzureChatCompletion(BaseLLM):
                acompletion: bool = False,
                headers: Optional[dict]=None):
        super().completion()
        if self._client_session is None:
            self._client_session = self.create_client_session()
        exception_mapping_worked = False
        try:
            if headers is None:
@@ -157,15 +159,24 @@ class AzureChatCompletion(BaseLLM):
                raise e

    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
        async with httpx.AsyncClient(timeout=600) as client:
        if self._aclient_session is None:
            self._aclient_session = self.create_aclient_session()
        client = self._aclient_session
        try:
            response = await client.post(api_base, json=data, headers=headers)
            response_json = response.json()
            if response.status_code != 200:
                raise AzureOpenAIError(status_code=response.status_code, message=response.text)

                raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)

            ## RESPONSE OBJECT
            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
        except Exception as e:
            if httpx.TimeoutException:
                raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
            elif response and hasattr(response, "text"):
                raise AzureOpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
            else:
                raise AzureOpenAIError(status_code=500, message=f"{str(e)}")

    def streaming(self,
                logging_obj,
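acompletion now reuses a long-lived AsyncClient instead of opening a fresh `async with httpx.AsyncClient(...)` block per call. A rough standalone sketch of the request/error flow, with a generic RuntimeError standing in for AzureOpenAIError and an explicit isinstance-style catch for timeouts (a bare `if httpx.TimeoutException:` is always truthy, so distinguishing timeouts reliably needs the exception instance):

import httpx
from typing import Optional

_aclient: Optional[httpx.AsyncClient] = None

async def post_and_parse(api_base: str, data: dict, headers: dict) -> dict:
    global _aclient
    if _aclient is None:
        # build the shared async session once; later calls reuse its connection pool
        _aclient = httpx.AsyncClient(timeout=600)
    try:
        response = await _aclient.post(api_base, json=data, headers=headers)
        if response.status_code != 200:
            # surface the provider's error body alongside the status code
            raise RuntimeError(f"status={response.status_code}: {response.text}")
        return response.json()
    except httpx.TimeoutException as e:
        raise RuntimeError("Request Timeout Error") from e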
@@ -153,14 +153,11 @@ class OpenAITextCompletionConfig():
                and v is not None}

class OpenAIChatCompletion(BaseLLM):
    _client_session: httpx.Client
    _aclient_session: httpx.AsyncClient
    _client_session: Optional[httpx.Client] = None
    _aclient_session: Optional[httpx.AsyncClient] = None

    def __init__(self) -> None:
        super().__init__()
        self._client_session = self.create_client_session()
        self._aclient_session = self.create_aclient_session()
        self._num_retry_httpx_errors = 3 # httpx throws random errors - e.g. ReadError,

    def validate_environment(self, api_key):
        headers = {
@@ -193,6 +190,8 @@ class OpenAIChatCompletion(BaseLLM):
                logger_fn=None,
                headers: Optional[dict]=None):
        super().completion()
        if self._client_session is None:
            self._client_session = self.create_client_session()
        exception_mapping_worked = False
        try:
            if headers is None:
@@ -264,6 +263,8 @@ class OpenAIChatCompletion(BaseLLM):
                          data: dict, headers: dict,
                          model_response: ModelResponse):
        kwargs = locals()
        if self._aclient_session is None:
            self._aclient_session = self.create_aclient_session()
        client = self._aclient_session
        try:
            response = await client.post(api_base, json=data, headers=headers)
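`kwargs = locals()` snapshots the call's arguments at the top of acompletion so the identical request can be replayed later by a retry helper. A small sketch of that capture-and-replay idea (the function and its failure mode are made up for illustration, not litellm's `_retry_request`):

import asyncio

async def fetch(url: str, payload: dict, attempt: int = 0):
    # capture the arguments before anything can fail, so a retry can replay them
    call_args = dict(locals())
    try:
        if attempt < 2:
            raise ConnectionResetError("simulated transient failure")
        return {"url": url, "payload": payload}
    except ConnectionResetError:
        # replay the identical call with the attempt counter bumped
        call_args["attempt"] = attempt + 1
        return await fetch(**call_args)

print(asyncio.run(fetch("https://example.com", {"q": "hi"})))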
@@ -273,13 +274,9 @@ class OpenAIChatCompletion(BaseLLM):

            ## RESPONSE OBJECT
            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
        except httpx.ReadError or httpx.ReadTimeout:
            if self._num_retry_httpx_errors > 0:
                kwargs["original_function"] = self.acompletion
                return self._retry_request(**kwargs)
            else:
                raise e
        except Exception as e:
            if httpx.TimeoutException:
                raise OpenAIError(status_code=500, message="Request Timeout Error")
            if response and hasattr(response, "text"):
                raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
            else:
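The except clause above is meant to retry on transient httpx read failures; note that `except httpx.ReadError or httpx.ReadTimeout:` evaluates the `or` first and therefore only catches ReadError. A hedged sketch of the intended behaviour, using an exception tuple and a bounded retry budget (the function name and counter are illustrative, not litellm's API):

import httpx

async def post_with_read_retries(client: httpx.AsyncClient, url: str, payload: dict, retries_left: int = 3) -> dict:
    try:
        response = await client.post(url, json=payload)
        return response.json()
    except (httpx.ReadError, httpx.ReadTimeout):
        # httpx occasionally raises transient read errors; retry a bounded number of times
        if retries_left > 0:
            return await post_with_read_retries(client, url, payload, retries_left - 1)
        raise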
@@ -236,9 +236,6 @@ def mock_completion(model: str, messages: List, stream: Optional[bool] = False,
        raise Exception("Mock completion response failed")

@client
@timeout( # type: ignore
    600
) ## set timeouts, in case calls hang (e.g. Azure) - default is 600s, override with `force_timeout`
def completion(
    model: str,
    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
@@ -2,7 +2,7 @@ from datetime import datetime
from typing import Dict, List, Optional, Union, Literal
import random, threading, time
import litellm, openai
import logging
import logging, asyncio

class Router:
    """
@@ -23,6 +23,8 @@ class Router:
    model_names: List = []
    cache_responses: bool = False
    default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
    num_retries: int = 0
    tenacity = None

    def __init__(self,
                 model_list: Optional[list] = None,
@@ -31,7 +33,9 @@ class Router:
                 redis_password: Optional[str] = None,
                 cache_responses: bool = False,
                 num_retries: Optional[int] = None,
                 timeout: float = 600,
                 routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:

        if model_list:
            self.set_model_list(model_list)
            self.healthy_deployments: List = self.model_list
@@ -39,6 +43,7 @@ class Router:
        if num_retries:
            self.num_retries = num_retries

        litellm.request_timeout = timeout
        self.routing_strategy = routing_strategy
        ### HEALTH CHECK THREAD ###
        if self.routing_strategy == "least-busy":
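With these changes the Router accepts `num_retries` and a `timeout` that it pushes into `litellm.request_timeout`. A small usage sketch mirroring the model list used in the tests below, assuming OPENAI_API_KEY is set in the environment:

import os
from litellm import Router

model_list = [{
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {
        "model": "gpt-3.5-turbo",
        "api_key": os.getenv("OPENAI_API_KEY"),
    },
}]

# retry failed calls up to 3 times and cap each request at 10 seconds
router = Router(model_list=model_list, num_retries=3, timeout=10)
response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
print(response)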
@@ -132,6 +137,37 @@ class Router:

        raise ValueError("No models available.")

    def retry_if_rate_limit_error(self, exception):
        return isinstance(exception, openai.RateLimitError)

    def retry_if_api_error(self, exception):
        return isinstance(exception, openai.APIError)

    async def async_function_with_retries(self, *args, **kwargs):
        # we'll backoff exponentially with each retry
        backoff_factor = 1
        original_exception = kwargs.pop("original_exception")
        original_function = kwargs.pop("original_function")
        for current_attempt in range(self.num_retries):
            try:
                # if the function call is successful, no exception will be raised and we'll break out of the loop
                return await original_function(*args, **kwargs)

            except openai.RateLimitError as e:
                # on RateLimitError we'll wait for an exponential time before trying again
                await asyncio.sleep(backoff_factor)

                # increase backoff factor for next run
                backoff_factor *= 2

            except openai.APIError as e:
                # on APIError we immediately retry without any wait, change this if necessary
                pass

            except Exception as e:
                # for any other exception types, don't retry
                raise e

    def function_with_retries(self, *args, **kwargs):
        try:
            import tenacity
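This is the core of the commit: a hand-rolled retry loop that sleeps for an exponentially growing interval on rate limits instead of pulling in tenacity. A standalone sketch of the same loop, decoupled from the Router (the flaky `call_api` coroutine and the local RateLimitError class are made up for illustration):

import asyncio

class RateLimitError(Exception):
    """Stand-in for openai.RateLimitError."""

async def with_exponential_backoff(func, *args, num_retries: int = 3, **kwargs):
    backoff_factor = 1
    for current_attempt in range(num_retries):
        try:
            # a successful call returns immediately and ends the loop
            return await func(*args, **kwargs)
        except RateLimitError:
            # wait, then double the wait for the next attempt: 1s, 2s, 4s, ...
            await asyncio.sleep(backoff_factor)
            backoff_factor *= 2
    # out of retries: make one final attempt and let any error propagate
    return await func(*args, **kwargs)

async def call_api(state={"calls": 0}):
    # made-up coroutine that rate-limits the first two calls
    state["calls"] += 1
    if state["calls"] < 3:
        raise RateLimitError("slow down")
    return "ok"

print(asyncio.run(with_exponential_backoff(call_api)))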
@@ -144,6 +180,9 @@ class Router:
                retry_info["attempts"] = retry_state.attempt_number
                retry_info["final_result"] = retry_state.outcome.result()

            if 'model' not in kwargs or 'messages' not in kwargs:
                raise ValueError("'model' and 'messages' must be included as keyword arguments")

            try:
                original_exception = kwargs.pop("original_exception")
                original_function = kwargs.pop("original_function")
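function_with_retries keeps its tenacity-based Retrying object for the synchronous path; only the async path switches to the hand-rolled loop above. A compact sketch of the tenacity pattern in isolation, assuming tenacity is installed (the flaky function and its exception are illustrative):

import tenacity

class RateLimitError(Exception):
    pass

attempts = {"n": 0}

def flaky():
    # fails twice, then succeeds
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RateLimitError("retry me")
    return "ok"

retryer = tenacity.Retrying(
    retry=tenacity.retry_if_exception_type(RateLimitError),
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_exponential(multiplier=1, max=10),
    reraise=True,
)
print(retryer(flaky))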
@@ -157,7 +196,7 @@ class Router:
                                          reraise=True,
                                          after=after_callback)

                return retryer(original_function, *args, **kwargs)
                return retryer(self.acompletion, *args, **kwargs)
            except Exception as e:
                raise Exception(f"Error in function_with_retries: {e}\n\nRetry Info: {retry_info}")

@@ -180,7 +219,6 @@ class Router:
            return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})



    async def acompletion(self,
                          model: str,
                          messages: List[Dict[str, str]],
@@ -197,7 +235,7 @@ class Router:
            kwargs["messages"] = messages
            kwargs["original_exception"] = e
            kwargs["original_function"] = self.acompletion
            return self.function_with_retries(**kwargs)
            return await self.async_function_with_retries(**kwargs)

    def text_completion(self,
                        model: str,
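On failure, acompletion now packs the original arguments plus the triggering exception into kwargs and awaits the async retry coroutine directly, rather than returning the synchronous tenacity path. A tiny sketch of that hand-off with a made-up retry helper and a simulated failure:

import asyncio

async def retry_helper(*args, **kwargs):
    original_exception = kwargs.pop("original_exception")
    original_function = kwargs.pop("original_function")
    print(f"retrying after: {original_exception}")
    return await original_function(*args, retry=True, **kwargs)

async def acompletion(model: str, messages: list, retry: bool = False, **kwargs):
    try:
        if not retry:
            raise RuntimeError("simulated transient failure")
        return {"model": model, "choices": [{"message": {"content": "hello"}}]}
    except Exception as e:
        # hand the whole call off to the retry coroutine, exactly as the Router does
        kwargs["model"] = model
        kwargs["messages"] = messages
        kwargs["original_exception"] = e
        kwargs["original_function"] = acompletion
        return await retry_helper(**kwargs)

print(asyncio.run(acompletion("gpt-3.5-turbo", [{"role": "user", "content": "hi"}])))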
@@ -1,63 +1,69 @@
import sys, os
import traceback
from dotenv import load_dotenv
import copy
# import sys, os
# import traceback
# from dotenv import load_dotenv
# import copy

load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
from litellm import Router
# load_dotenv()
# sys.path.insert(
#     0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import asyncio
# from litellm import Router, Timeout

async def call_acompletion(semaphore, router: Router, input_data):
    async with semaphore:
        # Replace 'input_data' with appropriate parameters for acompletion
        response = await router.acompletion(**input_data)
        # Handle the response as needed
        return response

async def main():
    # Initialize the Router
    model_list= [{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    }, {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_version": os.getenv("AZURE_API_VERSION")
        },
    }, {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-functioncalling",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_version": os.getenv("AZURE_API_VERSION")
        },
    }]
    router = Router(model_list=model_list, num_retries=3)
# async def call_acompletion(semaphore, router: Router, input_data):
#     async with semaphore:
#         try:
#             # Use asyncio.wait_for to set a timeout for the task
#             response = await router.acompletion(**input_data)
#             # Handle the response as needed
#             return response
#         except Timeout:
#             print(f"Task timed out: {input_data}")
#             return None # You may choose to return something else or raise an exception

    # Create a semaphore with a capacity of 100
    semaphore = asyncio.Semaphore(100)

    # List to hold all task references
    tasks = []
# async def main():
#     # Initialize the Router
#     model_list= [{
#         "model_name": "gpt-3.5-turbo",
#         "litellm_params": {
#             "model": "gpt-3.5-turbo",
#             "api_key": os.getenv("OPENAI_API_KEY"),
#         },
#     }, {
#         "model_name": "gpt-3.5-turbo",
#         "litellm_params": {
#             "model": "azure/chatgpt-v-2",
#             "api_key": os.getenv("AZURE_API_KEY"),
#             "api_base": os.getenv("AZURE_API_BASE"),
#             "api_version": os.getenv("AZURE_API_VERSION")
#         },
#     }, {
#         "model_name": "gpt-3.5-turbo",
#         "litellm_params": {
#             "model": "azure/chatgpt-functioncalling",
#             "api_key": os.getenv("AZURE_API_KEY"),
#             "api_base": os.getenv("AZURE_API_BASE"),
#             "api_version": os.getenv("AZURE_API_VERSION")
#         },
#     }]
#     router = Router(model_list=model_list, num_retries=3, timeout=10)

    # Launch 1000 tasks
    for _ in range(100):
        task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
        tasks.append(task)
#     # Create a semaphore with a capacity of 100
#     semaphore = asyncio.Semaphore(100)

    # Wait for all tasks to complete
    responses = await asyncio.gather(*tasks)
    # Process responses as needed
    print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
# Run the main function
asyncio.run(main())
#     # List to hold all task references
#     tasks = []

#     # Launch 1000 tasks
#     for _ in range(1000):
#         task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
#         tasks.append(task)

#     # Wait for all tasks to complete
#     responses = await asyncio.gather(*tasks)
#     # Process responses as needed
#     print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
#     # Run the main function
# asyncio.run(main())
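The (now commented-out) script above is a load test: it fans out 100 concurrent router.acompletion calls while an asyncio.Semaphore caps how many run at once. A trimmed, generic sketch of that pattern with a dummy coroutine standing in for the router call:

import asyncio

async def call_acompletion(semaphore: asyncio.Semaphore, input_data: dict):
    async with semaphore:
        # stand-in for `await router.acompletion(**input_data)`
        await asyncio.sleep(0.01)
        return {"echo": input_data["messages"][-1]["content"]}

async def main():
    # allow at most 100 in-flight calls at a time
    semaphore = asyncio.Semaphore(100)
    tasks = [
        asyncio.create_task(call_acompletion(semaphore, {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}))
        for _ in range(100)
    ]
    responses = await asyncio.gather(*tasks)
    print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")

asyncio.run(main())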
@@ -246,7 +246,7 @@ def test_acompletion_on_router():
    ]

    async def get_response():
        router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True)
        router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True, timeout=10)
        response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
        print(f"response1: {response1}")
        response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
@@ -3236,6 +3236,15 @@ def exception_type(
            exception_type = type(original_exception).__name__
        else:
            exception_type = ""

        if "Request Timeout Error" in error_str:
            exception_mapping_worked = True
            raise Timeout(
                message=f"APITimeoutError - Request timed out",
                model=model,
                llm_provider=custom_llm_provider
            )

        if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai":
            if "This model's maximum context length is" in error_str or "Request too large" in error_str:
                exception_mapping_worked = True
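exception_type now promotes the provider-agnostic "Request Timeout Error" string (raised by the httpx wrappers above) into litellm's Timeout exception before any provider-specific mapping runs. A minimal sketch of that string-based mapping, with simplified local exception classes rather than litellm's own:

class Timeout(Exception):
    def __init__(self, message: str, model: str, llm_provider: str):
        self.model = model
        self.llm_provider = llm_provider
        super().__init__(message)

def map_exception(original_exception: Exception, model: str, custom_llm_provider: str):
    error_str = str(original_exception)
    if "Request Timeout Error" in error_str:
        # normalise any provider's timeout into a single exception type
        raise Timeout(
            message="APITimeoutError - Request timed out",
            model=model,
            llm_provider=custom_llm_provider,
        )
    raise original_exception

try:
    map_exception(RuntimeError("Request Timeout Error"), model="gpt-3.5-turbo", custom_llm_provider="azure")
except Timeout as e:
    print(f"caught timeout for {e.model} on {e.llm_provider}")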
@@ -4081,8 +4090,12 @@ def exception_type(
                        response=original_exception.response
                    )
                else:
                    # raise BadRequestError(message=str(original_exception), llm_provider=custom_llm_provider, model=model, response=original_exception.response)
                    raise original_exception
                    exception_mapping_worked = True
                    raise APIConnectionError(
                        message=f"{str(original_exception)}",
                        llm_provider=custom_llm_provider,
                        model=model
                    )
    except Exception as e:
        # LOGGING
        exception_logging(
@@ -8,7 +8,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = "^3.8"
openai = "*"
openai = ">=1.0.0"
python-dotenv = ">=0.2.0"
tiktoken = ">=0.4.0"
importlib-metadata = ">=6.8.0"