Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

fix(main.py): keep client consistent across calls + exponential backoff retry on ratelimit errors

This commit is contained in:
parent 5963d9d283
commit a7222f257c

9 changed files with 239 additions and 131 deletions
@@ -19,6 +19,7 @@ telemetry = True
 max_tokens = 256 # OpenAI Defaults
 drop_params = False
 retry = True
+request_timeout: float = 600
 api_key: Optional[str] = None
 openai_key: Optional[str] = None
 azure_key: Optional[str] = None

@@ -46,6 +47,7 @@ _current_cost = 0 # private variable, used if max budget is set
 error_logs: Dict = {}
 add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
 client_session: Optional[httpx.Client] = None
+aclient_session: Optional[httpx.AsyncClient] = None
 model_fallbacks: Optional[List] = None
 model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
 num_retries: Optional[int] = None
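
The two hunks above expose module-level settings for a shared httpx client, a shared async client, and a default request timeout, so one connection pool is reused instead of a fresh client being built per call. A minimal usage sketch, assuming these attributes are read by litellm exactly as declared above (the timeout and retry values are illustrative):

import httpx
import litellm

# Reuse one connection pool for every synchronous call.
litellm.client_session = httpx.Client(timeout=30)

# Reuse one pool for async calls via the new aclient_session attribute.
litellm.aclient_session = httpx.AsyncClient(timeout=30)

# Default timeout applied when no session is supplied (new request_timeout field).
litellm.request_timeout = 600

# Optional: retries picked up by the retry logic added later in this commit.
litellm.num_retries = 2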
@@ -105,7 +105,7 @@ class APIError(APIError):  # type: ignore
         super().__init__(
             self.message,
             request=request,  # type: ignore
-            body=None,
+            body=None
         )

 # raised if an invalid request (not get, delete, put, post) is made
@@ -9,13 +9,25 @@ class BaseLLM:
         if litellm.client_session:
             _client_session = litellm.client_session
         else:
-            _client_session = httpx.Client(timeout=600)
+            _client_session = httpx.Client(timeout=litellm.request_timeout)

         return _client_session

+    def create_aclient_session(self):
+        if litellm.aclient_session:
+            _aclient_session = litellm.aclient_session
+        else:
+            _aclient_session = httpx.AsyncClient(timeout=litellm.request_timeout)
+
+        return _aclient_session
+
     def __exit__(self):
         if hasattr(self, '_client_session'):
             self._client_session.close()

+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        if hasattr(self, '_aclient_session'):
+            await self._aclient_session.aclose()
+
     def validate_environment(self): # set up the environment required to run the model
         pass
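
The new __aexit__ hook lets an instance of a BaseLLM subclass be used with `async with`, closing the pooled httpx.AsyncClient on exit. A standalone sketch of that protocol under the same assumptions (the class and URL below are illustrative, not part of the diff):

import asyncio
import httpx

class ExampleClient:
    def __init__(self) -> None:
        # One AsyncClient reused for every request made through this object.
        self._aclient_session = httpx.AsyncClient(timeout=600)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Mirrors BaseLLM.__aexit__ above: close the pooled async session.
        await self._aclient_session.aclose()

async def main():
    async with ExampleClient() as client:
        resp = await client._aclient_session.get("https://example.com")
        print(resp.status_code)

asyncio.run(main())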
@@ -154,10 +154,12 @@ class OpenAITextCompletionConfig():

 class OpenAIChatCompletion(BaseLLM):
     _client_session: httpx.Client
+    _aclient_session: httpx.AsyncClient

     def __init__(self) -> None:
         super().__init__()
         self._client_session = self.create_client_session()
+        self._aclient_session = self.create_aclient_session()

     def validate_environment(self, api_key):
         headers = {
@@ -251,15 +253,15 @@ class OpenAIChatCompletion(BaseLLM):
                           api_base: str,
                           data: dict, headers: dict,
                           model_response: ModelResponse):
-        async with httpx.AsyncClient(timeout=600) as client:
-            response = await client.post(api_base, json=data, headers=headers)
-            response_json = response.json()
-            if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-            ## RESPONSE OBJECT
-            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+        client = self._aclient_session
+        response = await client.post(api_base, json=data, headers=headers)
+        response_json = response.json()
+        if response.status_code != 200:
+            raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+
+        ## RESPONSE OBJECT
+        return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)

     def streaming(self,
                   logging_obj,
@@ -290,8 +292,7 @@ class OpenAIChatCompletion(BaseLLM):
                         headers: dict,
                         model_response: ModelResponse,
                         model: str):
-        client = httpx.AsyncClient()
-        async with client.stream(
+        async with self._aclient_session.stream(
             url=f"{api_base}",
             json=data,
             headers=headers,
@@ -6,11 +6,14 @@ import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 import litellm
+import httpx

 class VertexAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
@@ -108,6 +108,7 @@ class Completions():
         response = completion(model=model, messages=messages, **self.params)
         return response

+@client
 async def acompletion(*args, **kwargs):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -149,63 +150,51 @@ async def acompletion(*args, **kwargs):
     """
     loop = asyncio.get_event_loop()
     model = args[0] if len(args) > 0 else kwargs["model"]
-    messages = args[1] if len(args) > 1 else kwargs["messages"]
-    ### INITIALIZE LOGGING OBJECT ###
-    kwargs["litellm_call_id"] = str(uuid.uuid4())
-    start_time = datetime.datetime.now()
-    logging_obj = Logging(model=model, messages=messages, stream=kwargs.get("stream", False), litellm_call_id=kwargs["litellm_call_id"], function_id=kwargs.get("id", None), call_type="completion", start_time=start_time)

     ### PASS ARGS TO COMPLETION ###
-    kwargs["litellm_logging_obj"] = logging_obj
     kwargs["acompletion"] = True
-    kwargs["model"] = model
-    kwargs["messages"] = messages
-    # Use a partial function to pass your keyword arguments
-    func = partial(completion, *args, **kwargs)
-
-    # Add the context to the function
-    ctx = contextvars.copy_context()
-    func_with_context = partial(ctx.run, func)
-
-    _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
-
-    if (custom_llm_provider == "openai"
-        or custom_llm_provider == "azure"
-        or custom_llm_provider == "custom_openai"
-        or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
-        if kwargs.get("stream", False):
-            response = completion(*args, **kwargs)
-        else:
-            # Await normally
-            init_response = completion(*args, **kwargs)
-            if isinstance(init_response, dict) or isinstance(init_response, ModelResponse):
-                response = init_response
-            else:
-                response = await init_response
-    else:
-        # Call the synchronous function using run_in_executor
-        response = await loop.run_in_executor(None, func_with_context)
-    if kwargs.get("stream", False): # return an async generator
-        # do not change this
-        # for stream = True, always return an async generator
-        # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
-        # return response
-        return(
-            line
-            async for line in response
-        )
-    else:
-        end_time = datetime.datetime.now()
-        # [OPTIONAL] ADD TO CACHE
-        if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
-            litellm.cache.add_cache(response, *args, **kwargs)
-
-        # LOG SUCCESS
-        logging_obj.success_handler(response, start_time, end_time)
-        # RETURN RESULT
-        response._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
-
-    return response
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(completion, *args, **kwargs)
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+
+        if (custom_llm_provider == "openai"
+            or custom_llm_provider == "azure"
+            or custom_llm_provider == "custom_openai"
+            or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
+            if kwargs.get("stream", False):
+                response = completion(*args, **kwargs)
+            else:
+                # Await normally
+                init_response = completion(*args, **kwargs)
+                if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
+                    response = init_response
+                else:
+                    response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        if kwargs.get("stream", False): # return an async generator
+            # do not change this
+            # for stream = True, always return an async generator
+            # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
+            # return response
+            return(
+                line
+                async for line in response
+            )
+        else:
+            return response
+    except Exception as e:
+        ## Map to OpenAI Exception
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+        )

 def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
     """
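
With the rewrite above, callers just await acompletion and let the wrapper handle logging, caching, and retries. A minimal call sketch (model name and message are placeholders; setting num_retries is optional and assumes the retry wiring shown later in this commit):

import asyncio
import litellm

async def main():
    # On openai.RateLimitError the async wrapper retries with exponential
    # backoff when num_retries is set (see the utils.py hunks below).
    litellm.num_retries = 2
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)

asyncio.run(main())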
@@ -1420,8 +1409,14 @@ def completion_with_retries(*args, **kwargs):
         raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")

     num_retries = kwargs.pop("num_retries", 3)
-    retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
-    return retryer(completion, *args, **kwargs)
+    retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
+    original_function = kwargs.pop("original_function", completion)
+    if retry_strategy == "constant_retry":
+        retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
+    elif retry_strategy == "exponential_backoff_retry":
+        retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
+    return retryer(original_function, *args, **kwargs)


 def batch_completion(
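
For reference, the two retry strategies selected above map onto tenacity's stop/wait primitives. This standalone sketch only demonstrates those primitives with a dummy function; it does not call litellm:

import tenacity

attempts = {"n": 0}

def flaky():
    # Fails twice, then succeeds, to exercise the retry loop.
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

# "constant_retry": retry immediately, up to num_retries attempts.
constant = tenacity.Retrying(stop=tenacity.stop_after_attempt(3), reraise=True)

# "exponential_backoff_retry": exponentially increasing waits, capped at 10s.
backoff = tenacity.Retrying(
    wait=tenacity.wait_exponential(multiplier=1, max=10),
    stop=tenacity.stop_after_attempt(3),
    reraise=True,
)

print(constant(flaky))  # "ok" on the third attempt
attempts["n"] = 0
print(backoff(flaky))   # "ok", with backoff between attempts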
litellm/tests/test_loadtest_router.py  (new file, 47 lines)
@@ -0,0 +1,47 @@
+import sys, os
+import traceback
+from dotenv import load_dotenv
+import copy
+
+load_dotenv()
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+from litellm import Router
+
+async def call_acompletion(semaphore, router: Router, input_data):
+    async with semaphore:
+        # Replace 'input_data' with appropriate parameters for acompletion
+        response = await router.acompletion(**input_data)
+        # Handle the response as needed
+        return response
+
+async def main():
+    # Initialize the Router
+    model_list= [{
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "gpt-3.5-turbo",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    }]
+    router = Router(model_list=model_list, num_retries=3)
+
+    # Create a semaphore with a capacity of 100
+    semaphore = asyncio.Semaphore(100)
+
+    # List to hold all task references
+    tasks = []
+
+    # Launch 1000 tasks
+    for _ in range(1000):
+        task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}))
+        tasks.append(task)
+
+    # Wait for all tasks to complete
+    responses = await asyncio.gather(*tasks)
+    # Process responses as needed
+
+# Run the main function
+asyncio.run(main())
@@ -317,7 +317,7 @@ def test_function_calling_on_router():
     except Exception as e:
         print(f"An exception occurred: {e}")

-test_function_calling_on_router()
+# test_function_calling_on_router()

 def test_aembedding_on_router():
     try:
litellm/utils.py  (180 changes)
@@ -456,6 +456,7 @@ from enum import Enum
 class CallTypes(Enum):
     embedding = 'embedding'
     completion = 'completion'
+    acompletion = 'acompletion'

 # Logging function -> log the exact model details + what's being sent | Non-Blocking
 class Logging:
@@ -984,7 +985,7 @@ def exception_logging(
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
 def client(original_function):
     global liteDebuggerClient, get_all_keys
+    import inspect
     def function_setup(
         start_time, *args, **kwargs
     ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -1036,7 +1037,7 @@ def client(original_function):
         # INIT LOGGER - for user-specified integrations
         model = args[0] if len(args) > 0 else kwargs["model"]
         call_type = original_function.__name__
-        if call_type == CallTypes.completion.value:
+        if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
             if len(args) > 1:
                 messages = args[1]
             elif kwargs.get("messages", None):
@@ -1183,7 +1184,107 @@ def client(original_function):
             ): # make it easy to get to the debugger logs if you've initialized it
                 e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
             raise e
-    return wrapper
+
+    async def wrapper_async(*args, **kwargs):
+        start_time = datetime.datetime.now()
+        result = None
+        logging_obj = kwargs.get("litellm_logging_obj", None)
+
+        # only set litellm_call_id if its not in kwargs
+        if "litellm_call_id" not in kwargs:
+            kwargs["litellm_call_id"] = str(uuid.uuid4())
+        try:
+            model = args[0] if len(args) > 0 else kwargs["model"]
+        except:
+            raise ValueError("model param not passed in.")
+
+        try:
+            if logging_obj is None:
+                logging_obj = function_setup(start_time, *args, **kwargs)
+            kwargs["litellm_logging_obj"] = logging_obj
+
+            # [OPTIONAL] CHECK BUDGET
+            if litellm.max_budget:
+                if litellm._current_cost > litellm.max_budget:
+                    raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget)
+
+            # [OPTIONAL] CHECK CACHE
+            print_verbose(f"litellm.cache: {litellm.cache}")
+            print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}")
+            # if caching is false, don't run this
+            if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function
+                # checking cache
+                if (litellm.cache != None):
+                    print_verbose(f"Checking Cache")
+                    cached_result = litellm.cache.get_cache(*args, **kwargs)
+                    if cached_result != None:
+                        print_verbose(f"Cache Hit!")
+                        call_type = original_function.__name__
+                        if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
+                            return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
+                        else:
+                            return cached_result
+            # MODEL CALL
+            result = original_function(*args, **kwargs)
+            end_time = datetime.datetime.now()
+            if "stream" in kwargs and kwargs["stream"] == True:
+                if "complete_response" in kwargs and kwargs["complete_response"] == True:
+                    chunks = []
+                    for idx, chunk in enumerate(result):
+                        chunks.append(chunk)
+                    return litellm.stream_chunk_builder(chunks)
+                else:
+                    return result
+            result = await result
+            # [OPTIONAL] ADD TO CACHE
+            if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
+                litellm.cache.add_cache(result, *args, **kwargs)
+
+            # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
+            logging_obj.success_handler(result, start_time, end_time)
+            # RETURN RESULT
+            return result
+        except Exception as e:
+            call_type = original_function.__name__
+            if call_type == CallTypes.acompletion.value:
+                num_retries = (
+                    kwargs.get("num_retries", None)
+                    or litellm.num_retries
+                    or None
+                )
+                litellm.num_retries = None # set retries to None to prevent infinite loops
+                context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {})
+
+                if num_retries:
+                    kwargs["num_retries"] = num_retries
+                    kwargs["original_function"] = original_function
+                    if (isinstance(e, openai.RateLimitError)): # rate limiting specific error
+                        kwargs["retry_strategy"] = "exponential_backoff_retry"
+                    elif (isinstance(e, openai.APIError)): # generic api error
+                        kwargs["retry_strategy"] = "constant_retry"
+                    return litellm.completion_with_retries(*args, **kwargs)
+                elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict:
+                    if len(args) > 0:
+                        args[0] = context_window_fallback_dict[model]
+                    else:
+                        kwargs["model"] = context_window_fallback_dict[model]
+                    return original_function(*args, **kwargs)
+            traceback_exception = traceback.format_exc()
+            crash_reporting(*args, **kwargs, exception=traceback_exception)
+            end_time = datetime.datetime.now()
+            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
+            if logging_obj:
+                threading.Thread(target=logging_obj.failure_handler, args=(e, traceback_exception, start_time, end_time)).start()
+            raise e
+
+    # Use httpx to determine if the original function is a coroutine
+    is_coroutine = inspect.iscoroutinefunction(original_function)
+
+    # Return the appropriate wrapper based on the original function type
+    if is_coroutine:
+        return wrapper_async
+    else:
+        return wrapper
+
 ####### USAGE CALCULATOR ################
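
The decorator above now returns a synchronous or asynchronous wrapper depending on what it decorates. A condensed sketch of that selection pattern (names here are illustrative, not the litellm implementation):

import asyncio
import functools
import inspect

def client(original_function):
    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        # Sync path: pre/post-processing around a plain call.
        return original_function(*args, **kwargs)

    @functools.wraps(original_function)
    async def wrapper_async(*args, **kwargs):
        # Async path: the same wrapping, but awaiting the coroutine.
        return await original_function(*args, **kwargs)

    # Pick the wrapper that matches the decorated function's type.
    if inspect.iscoroutinefunction(original_function):
        return wrapper_async
    return wrapper

@client
def add(a, b):
    return a + b

@client
async def aadd(a, b):
    return a + b

print(add(1, 2))                # 3, via the sync wrapper
print(asyncio.run(aadd(1, 2)))  # 3, via the async wrapper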
@@ -3116,31 +3217,13 @@ def exception_type(
         print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.") # noqa
         print() # noqa
     try:
-        if isinstance(original_exception, OriginalError):
-            # Handle the OpenAIError
-            exception_mapping_worked = True
-            if custom_llm_provider == "openrouter":
-                if original_exception.http_status == 413:
-                    raise BadRequestError(
-                        message=str(original_exception),
-                        model=model,
-                        llm_provider="openrouter"
-                    )
-                original_exception.llm_provider = "openrouter"
-            if "This model's maximum context length is" in original_exception._message:
-                raise ContextWindowExceededError(
-                    message=str(original_exception),
-                    model=model,
-                    llm_provider=original_exception.llm_provider
-                )
-            raise original_exception
-        elif model:
+        if model:
             error_str = str(original_exception)
             if isinstance(original_exception, BaseException):
                 exception_type = type(original_exception).__name__
             else:
                 exception_type = ""
-            if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai":
+            if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai":
                 if "This model's maximum context length is" in error_str or "Request too large" in error_str:
                     exception_mapping_worked = True
                     raise ContextWindowExceededError(
@@ -3191,6 +3274,14 @@ def exception_type(
                         llm_provider="openai",
                         response=original_exception.response
                     )
+                elif original_exception.status_code == 503:
+                    exception_mapping_worked = True
+                    raise ServiceUnavailableError(
+                        message=f"OpenAIException - {original_exception.message}",
+                        model=model,
+                        llm_provider="openai",
+                        response=original_exception.response
+                    )
                 elif original_exception.status_code == 504: # gateway timeout error
                     exception_mapping_worked = True
                     raise Timeout(
@@ -3968,49 +4059,6 @@ def exception_type(
                         model=model,
                         request=original_exception.request
                     )
-        elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
-            if hasattr(original_exception, "status_code"):
-                exception_mapping_worked = True
-                if original_exception.status_code == 401:
-                    exception_mapping_worked = True
-                    raise AuthenticationError(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        llm_provider="custom_openai",
-                        model=model
-                    )
-                elif original_exception.status_code == 408:
-                    exception_mapping_worked = True
-                    raise Timeout(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        model=model,
-                        llm_provider="custom_openai",
-                        request=original_exception.request
-                    )
-                if original_exception.status_code == 422:
-                    exception_mapping_worked = True
-                    raise BadRequestError(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        model=model,
-                        llm_provider="custom_openai",
-                        response=original_exception.response
-                    )
-                elif original_exception.status_code == 429:
-                    exception_mapping_worked = True
-                    raise RateLimitError(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        model=model,
-                        llm_provider="custom_openai",
-                        response=original_exception.response
-                    )
-                else:
-                    exception_mapping_worked = True
-                    raise APIError(
-                        status_code=original_exception.status_code,
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        llm_provider="custom_openai",
-                        model=model,
-                        request=original_exception.request
-                    )
         if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk
             exception_mapping_worked = True
             raise BadRequestError(