fix(main.py): keep client consistent across calls + exponential backoff retry on ratelimit errors
This commit is contained in: parent 5963d9d283, commit a7222f257c
9 changed files with 239 additions and 131 deletions
@@ -19,6 +19,7 @@ telemetry = True
max_tokens = 256  # OpenAI Defaults
drop_params = False
retry = True
+request_timeout: float = 600
api_key: Optional[str] = None
openai_key: Optional[str] = None
azure_key: Optional[str] = None
@@ -46,6 +47,7 @@ _current_cost = 0  # private variable, used if max budget is set
error_logs: Dict = {}
add_function_to_prompt: bool = False  # if function calling not supported by api, append function call details to system prompt
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+num_retries: Optional[int] = None
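A minimal usage sketch for the settings above (attribute names taken from this diff; the completion call is the standard litellm API): configure the shared timeout, an explicit httpx client, and a default retry count once, before issuing calls.

import httpx
import litellm

# Reuse one connection pool for every call instead of building a new client per request.
litellm.request_timeout = 30  # seconds; the module default above is 600
litellm.client_session = httpx.Client(timeout=litellm.request_timeout)
litellm.num_retries = 3       # picked up by the retry logic added in this commit

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
)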
@@ -105,7 +105,7 @@ class APIError(APIError):  # type: ignore
        super().__init__(
            self.message,
            request=request,  # type: ignore
-            body=None,
+            body=None
        )

# raised if an invalid request (not get, delete, put, post) is made
@@ -9,14 +9,26 @@ class BaseLLM:
        if litellm.client_session:
            _client_session = litellm.client_session
        else:
-            _client_session = httpx.Client(timeout=600)
+            _client_session = httpx.Client(timeout=litellm.request_timeout)

        return _client_session

    def create_aclient_session(self):
        if litellm.aclient_session:
            _aclient_session = litellm.aclient_session
        else:
            _aclient_session = httpx.AsyncClient(timeout=litellm.request_timeout)

        return _aclient_session

    def __exit__(self):
        if hasattr(self, '_client_session'):
            self._client_session.close()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if hasattr(self, '_aclient_session'):
            await self._aclient_session.aclose()

    def validate_environment(self):  # set up the environment required to run the model
        pass
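A sketch of how an application might inject its own async session so every provider call reuses one connection pool (assuming litellm.aclient_session is honored as shown above, and that the caller owns the session it injects):

import asyncio
import httpx
import litellm

async def main():
    # One shared AsyncClient for all async completion calls.
    litellm.aclient_session = httpx.AsyncClient(timeout=litellm.request_timeout)
    try:
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "ping"}],
        )
        print(response)
    finally:
        # The caller created the session, so the caller closes it.
        await litellm.aclient_session.aclose()

asyncio.run(main())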
@@ -154,10 +154,12 @@ class OpenAITextCompletionConfig():

class OpenAIChatCompletion(BaseLLM):
    _client_session: httpx.Client
+    _aclient_session: httpx.AsyncClient

    def __init__(self) -> None:
        super().__init__()
        self._client_session = self.create_client_session()
+        self._aclient_session = self.create_aclient_session()

    def validate_environment(self, api_key):
        headers = {
@@ -251,13 +253,13 @@ class OpenAIChatCompletion(BaseLLM):
                          api_base: str,
                          data: dict, headers: dict,
                          model_response: ModelResponse):
-        async with httpx.AsyncClient(timeout=600) as client:
+        client = self._aclient_session

        response = await client.post(api_base, json=data, headers=headers)
        response_json = response.json()
        if response.status_code != 200:
            raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)

        ## RESPONSE OBJECT
        return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
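For context on the swap above, a small standalone illustration (generic httpx usage, not litellm code) of reusing one AsyncClient across many requests instead of constructing a fresh client inside each call, which is what keeps connections pooled:

import asyncio
import httpx

async def fetch_status(urls: list[str]) -> list[int]:
    # A single AsyncClient: TCP/TLS connections are pooled and reused across requests.
    async with httpx.AsyncClient(timeout=30) as client:
        responses = await asyncio.gather(*(client.get(url) for url in urls))
        return [r.status_code for r in responses]

print(asyncio.run(fetch_status(["https://example.com"] * 3)))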
@@ -290,8 +292,7 @@ class OpenAIChatCompletion(BaseLLM):
                          headers: dict,
                          model_response: ModelResponse,
                          model: str):
-        client = httpx.AsyncClient()
-        async with client.stream(
+        async with self._aclient_session.stream(
            url=f"{api_base}",
            json=data,
            headers=headers,
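A minimal sketch of consuming a streamed response through an async httpx client (illustrative only, not the litellm streaming handler; the endpoint URL is a placeholder):

import asyncio
import httpx

async def stream_lines(url: str, payload: dict) -> None:
    # One long-lived AsyncClient; stream the body and handle each line as it arrives.
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=payload) as response:
            async for line in response.aiter_lines():
                if line:
                    print(line)

# asyncio.run(stream_lines("https://example.com/stream", {"stream": True}))  # placeholder endpoint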
@@ -6,11 +6,14 @@ import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
import litellm
+import httpx

class VertexAIError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
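The pattern here, a provider error that carries a synthetic httpx request/response, lets the generic exception-mapping code downstream read .status_code, .request, and .response uniformly. A self-contained sketch with hypothetical names:

import httpx

class ProviderError(Exception):
    """Hypothetical provider error carrying a synthetic httpx request/response,
    so generic mapping code can treat every provider the same way."""
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(method="POST", url="https://provider.invalid/v1/")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)

err = ProviderError(429, "rate limited")
print(err.status_code, err.response.status_code)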
@@ -108,6 +108,7 @@ class Completions():
        response = completion(model=model, messages=messages, **self.params)
        return response

+@client
async def acompletion(*args, **kwargs):
    """
    Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -149,17 +150,9 @@ async def acompletion(*args, **kwargs):
    """
    loop = asyncio.get_event_loop()
    model = args[0] if len(args) > 0 else kwargs["model"]
    messages = args[1] if len(args) > 1 else kwargs["messages"]
    ### INITIALIZE LOGGING OBJECT ###
    kwargs["litellm_call_id"] = str(uuid.uuid4())
    start_time = datetime.datetime.now()
    logging_obj = Logging(model=model, messages=messages, stream=kwargs.get("stream", False), litellm_call_id=kwargs["litellm_call_id"], function_id=kwargs.get("id", None), call_type="completion", start_time=start_time)

    ### PASS ARGS TO COMPLETION ###
    kwargs["litellm_logging_obj"] = logging_obj
    kwargs["acompletion"] = True
    kwargs["model"] = model
    kwargs["messages"] = messages
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(completion, *args, **kwargs)
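The partial-then-await flow above hands the synchronous completion call off to the event loop. A generic sketch of that technique (the helper below is illustrative, not litellm code):

import asyncio
from functools import partial

def blocking_call(model: str, messages: list) -> dict:
    # Stand-in for a synchronous completion() call.
    return {"model": model, "echo": messages[-1]["content"]}

async def run_without_blocking() -> dict:
    loop = asyncio.get_event_loop()
    # Bind keyword arguments up front, then run the sync function in the default
    # executor so the event loop stays responsive while it works.
    func = partial(blocking_call, model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
    return await loop.run_in_executor(None, func)

print(asyncio.run(run_without_blocking()))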
@@ -178,7 +171,7 @@ async def acompletion(*args, **kwargs):
        else:
            # Await normally
            init_response = completion(*args, **kwargs)
-            if isinstance(init_response, dict) or isinstance(init_response, ModelResponse):
+            if isinstance(init_response, dict) or isinstance(init_response, ModelResponse):  ## CACHING SCENARIO
                response = init_response
            else:
                response = await init_response
@@ -195,17 +188,13 @@ async def acompletion(*args, **kwargs):
                async for line in response
            )
        else:
            end_time = datetime.datetime.now()
            # [OPTIONAL] ADD TO CACHE
            if litellm.caching or litellm.caching_with_models or litellm.cache != None:  # user init a cache object
                litellm.cache.add_cache(response, *args, **kwargs)

            # LOG SUCCESS
            logging_obj.success_handler(response, start_time, end_time)
            # RETURN RESULT
            response._response_ms = (end_time - start_time).total_seconds() * 1000  # return response latency in ms like openai

        return response
    except Exception as e:
        ## Map to OpenAI Exception
        raise exception_type(
            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
        )


def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
    """
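A tiny illustration of the `_response_ms` bookkeeping above: wall-clock latency measured with datetime and reported in milliseconds (generic helper, not part of litellm):

import datetime

def timed(fn, *args, **kwargs):
    # Measure wall-clock duration of a call and report it in milliseconds.
    start_time = datetime.datetime.now()
    result = fn(*args, **kwargs)
    end_time = datetime.datetime.now()
    elapsed_ms = (end_time - start_time).total_seconds() * 1000
    return result, elapsed_ms

result, ms = timed(sum, range(1_000_000))
print(result, f"{ms:.2f} ms")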
@@ -1420,8 +1409,14 @@ def completion_with_retries(*args, **kwargs):
        raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")

    num_retries = kwargs.pop("num_retries", 3)
    retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
    original_function = kwargs.pop("original_function", completion)
    if retry_strategy == "constant_retry":
        retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
        return retryer(completion, *args, **kwargs)
    elif retry_strategy == "exponential_backoff_retry":
        retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
        return retryer(original_function, *args, **kwargs)


def batch_completion(
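A self-contained sketch of the same tenacity configuration used in the exponential_backoff_retry branch above (the flaky function is a stand-in for a rate-limited completion call):

import tenacity

attempts = {"n": 0}

def flaky_call() -> str:
    # Fails twice, then succeeds - simulates transient 429 rate-limit errors.
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("429: rate limited")
    return "ok"

retryer = tenacity.Retrying(
    wait=tenacity.wait_exponential(multiplier=1, max=10),  # 1s, 2s, 4s, ... capped at 10s
    stop=tenacity.stop_after_attempt(3),
    reraise=True,  # surface the original exception if all attempts fail
)
print(retryer(flaky_call))  # prints "ok" on the third attempt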
47  litellm/tests/test_loadtest_router.py  (Normal file)
@@ -0,0 +1,47 @@
import sys, os
import traceback
from dotenv import load_dotenv
import copy

load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
from litellm import Router

async def call_acompletion(semaphore, router: Router, input_data):
    async with semaphore:
        # Replace 'input_data' with appropriate parameters for acompletion
        response = await router.acompletion(**input_data)
        # Handle the response as needed
        return response

async def main():
    # Initialize the Router
    model_list = [{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    }]
    router = Router(model_list=model_list, num_retries=3)

    # Create a semaphore with a capacity of 100
    semaphore = asyncio.Semaphore(100)

    # List to hold all task references
    tasks = []

    # Launch 1000 tasks
    for _ in range(1000):
        task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}))
        tasks.append(task)

    # Wait for all tasks to complete
    responses = await asyncio.gather(*tasks)
    # Process responses as needed

# Run the main function
asyncio.run(main())
@@ -317,7 +317,7 @@ def test_function_calling_on_router():
    except Exception as e:
        print(f"An exception occurred: {e}")

-test_function_calling_on_router()
+# test_function_calling_on_router()

def test_aembedding_on_router():
    try:
178  litellm/utils.py
@@ -456,6 +456,7 @@ from enum import Enum
class CallTypes(Enum):
    embedding = 'embedding'
    completion = 'completion'
+    acompletion = 'acompletion'

# Logging function -> log the exact model details + what's being sent | Non-Blocking
class Logging:
@@ -984,7 +985,7 @@ def exception_logging(
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function):
    global liteDebuggerClient, get_all_keys

    import inspect
    def function_setup(
        start_time, *args, **kwargs
    ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -1036,7 +1037,7 @@ def client(original_function):
        # INIT LOGGER - for user-specified integrations
        model = args[0] if len(args) > 0 else kwargs["model"]
        call_type = original_function.__name__
-        if call_type == CallTypes.completion.value:
+        if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
            if len(args) > 1:
                messages = args[1]
            elif kwargs.get("messages", None):
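A compact sketch of the dispatch idea above: the decorator reads the wrapped function's __name__ and compares it against the CallTypes enum (enum values copied from this diff; the helper itself is illustrative):

from enum import Enum

class CallTypes(Enum):
    embedding = 'embedding'
    completion = 'completion'
    acompletion = 'acompletion'

def needs_messages(func) -> bool:
    # completion-style calls (sync or async) carry a `messages` argument.
    call_type = func.__name__
    return call_type in (CallTypes.completion.value, CallTypes.acompletion.value)

async def acompletion(*args, **kwargs): ...
def embedding(*args, **kwargs): ...

print(needs_messages(acompletion), needs_messages(embedding))  # True False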
@@ -1183,6 +1184,106 @@ def client(original_function):
            ):  # make it easy to get to the debugger logs if you've initialized it
                e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
            raise e

    async def wrapper_async(*args, **kwargs):
        start_time = datetime.datetime.now()
        result = None
        logging_obj = kwargs.get("litellm_logging_obj", None)

        # only set litellm_call_id if its not in kwargs
        if "litellm_call_id" not in kwargs:
            kwargs["litellm_call_id"] = str(uuid.uuid4())
        try:
            model = args[0] if len(args) > 0 else kwargs["model"]
        except:
            raise ValueError("model param not passed in.")

        try:
            if logging_obj is None:
                logging_obj = function_setup(start_time, *args, **kwargs)
            kwargs["litellm_logging_obj"] = logging_obj

            # [OPTIONAL] CHECK BUDGET
            if litellm.max_budget:
                if litellm._current_cost > litellm.max_budget:
                    raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget)

            # [OPTIONAL] CHECK CACHE
            print_verbose(f"litellm.cache: {litellm.cache}")
            print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}")
            # if caching is false, don't run this
            if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True:  # allow users to control returning cached responses from the completion function
                # checking cache
                if (litellm.cache != None):
                    print_verbose(f"Checking Cache")
                    cached_result = litellm.cache.get_cache(*args, **kwargs)
                    if cached_result != None:
                        print_verbose(f"Cache Hit!")
                        call_type = original_function.__name__
                        if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
                            return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
                        else:
                            return cached_result
            # MODEL CALL
            result = original_function(*args, **kwargs)
            end_time = datetime.datetime.now()
            if "stream" in kwargs and kwargs["stream"] == True:
                if "complete_response" in kwargs and kwargs["complete_response"] == True:
                    chunks = []
                    for idx, chunk in enumerate(result):
                        chunks.append(chunk)
                    return litellm.stream_chunk_builder(chunks)
                else:
                    return result
            result = await result
            # [OPTIONAL] ADD TO CACHE
            if litellm.caching or litellm.caching_with_models or litellm.cache != None:  # user init a cache object
                litellm.cache.add_cache(result, *args, **kwargs)

            # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
            logging_obj.success_handler(result, start_time, end_time)
            # RETURN RESULT
            return result
        except Exception as e:
            call_type = original_function.__name__
            if call_type == CallTypes.acompletion.value:
                num_retries = (
                    kwargs.get("num_retries", None)
                    or litellm.num_retries
                    or None
                )
                litellm.num_retries = None  # set retries to None to prevent infinite loops
                context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {})

                if num_retries:
                    kwargs["num_retries"] = num_retries
                    kwargs["original_function"] = original_function
                    if (isinstance(e, openai.RateLimitError)):  # rate limiting specific error
                        kwargs["retry_strategy"] = "exponential_backoff_retry"
                    elif (isinstance(e, openai.APIError)):  # generic api error
                        kwargs["retry_strategy"] = "constant_retry"
                    return litellm.completion_with_retries(*args, **kwargs)
                elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict:
                    if len(args) > 0:
                        args[0] = context_window_fallback_dict[model]
                    else:
                        kwargs["model"] = context_window_fallback_dict[model]
                    return original_function(*args, **kwargs)
            traceback_exception = traceback.format_exc()
            crash_reporting(*args, **kwargs, exception=traceback_exception)
            end_time = datetime.datetime.now()
            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
            if logging_obj:
                threading.Thread(target=logging_obj.failure_handler, args=(e, traceback_exception, start_time, end_time)).start()
            raise e

    # Use httpx to determine if the original function is a coroutine
    is_coroutine = inspect.iscoroutinefunction(original_function)

    # Return the appropriate wrapper based on the original function type
    if is_coroutine:
        return wrapper_async
    else:
        return wrapper

####### USAGE CALCULATOR ################
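The headline behavior of this commit is in the except block above: async completion failures are retried, with the strategy chosen by exception type. A hedged sketch of that decision as a plain helper (litellm.completion_with_retries and the openai exception classes are real; the wrapper function here is illustrative, not the decorator itself):

import openai
import litellm

def complete_with_retry_policy(**kwargs):
    num_retries = kwargs.pop("num_retries", 3)
    try:
        return litellm.completion(**kwargs)
    except openai.RateLimitError:
        # Rate limits back off exponentially (1s, 2s, 4s, ...).
        return litellm.completion_with_retries(
            retry_strategy="exponential_backoff_retry", num_retries=num_retries, **kwargs
        )
    except openai.APIError:
        # Other transient API errors retry at a constant pace.
        return litellm.completion_with_retries(
            retry_strategy="constant_retry", num_retries=num_retries, **kwargs
        )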
@@ -3116,31 +3217,13 @@ def exception_type(
    print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.")  # noqa
    print()  # noqa
    try:
-        if isinstance(original_exception, OriginalError):
-            # Handle the OpenAIError
-            exception_mapping_worked = True
-            if custom_llm_provider == "openrouter":
-                if original_exception.http_status == 413:
-                    raise BadRequestError(
-                        message=str(original_exception),
-                        model=model,
-                        llm_provider="openrouter"
-                    )
-                original_exception.llm_provider = "openrouter"
-            if "This model's maximum context length is" in original_exception._message:
-                raise ContextWindowExceededError(
-                    message=str(original_exception),
-                    model=model,
-                    llm_provider=original_exception.llm_provider
-                )
-            raise original_exception
-        elif model:
+        if model:
            error_str = str(original_exception)
            if isinstance(original_exception, BaseException):
                exception_type = type(original_exception).__name__
            else:
                exception_type = ""
-            if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai":
+            if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai":
                if "This model's maximum context length is" in error_str or "Request too large" in error_str:
                    exception_mapping_worked = True
                    raise ContextWindowExceededError(
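A small, self-contained illustration of the string-matching mapping used above (the exception class here is a stand-in for litellm's typed ContextWindowExceededError):

class ContextWindowExceededError(Exception):
    """Illustrative stand-in for litellm's typed exception."""
    def __init__(self, message: str, model: str, llm_provider: str):
        super().__init__(message)
        self.model = model
        self.llm_provider = llm_provider

def map_provider_error(original_exception: Exception, model: str, provider: str):
    error_str = str(original_exception)
    # Same heuristic as the mapping above: recognize context-window errors by message text.
    if "This model's maximum context length is" in error_str or "Request too large" in error_str:
        raise ContextWindowExceededError(error_str, model, provider)
    raise original_exception

try:
    map_provider_error(RuntimeError("Request too large"), "gpt-3.5-turbo", "openai")
except ContextWindowExceededError as e:
    print(type(e).__name__, e.model, e.llm_provider)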
@@ -3191,6 +3274,14 @@ def exception_type(
                        llm_provider="openai",
                        response=original_exception.response
                    )
+                elif original_exception.status_code == 503:
+                    exception_mapping_worked = True
+                    raise ServiceUnavailableError(
+                        message=f"OpenAIException - {original_exception.message}",
+                        model=model,
+                        llm_provider="openai",
+                        response=original_exception.response
+                    )
                elif original_exception.status_code == 504:  # gateway timeout error
                    exception_mapping_worked = True
                    raise Timeout(
@@ -3968,49 +4059,6 @@ def exception_type(
                        model=model,
                        request=original_exception.request
                    )
-        elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
-            if hasattr(original_exception, "status_code"):
-                exception_mapping_worked = True
-                if original_exception.status_code == 401:
-                    exception_mapping_worked = True
-                    raise AuthenticationError(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        llm_provider="custom_openai",
-                        model=model
-                    )
-                elif original_exception.status_code == 408:
-                    exception_mapping_worked = True
-                    raise Timeout(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        model=model,
-                        llm_provider="custom_openai",
-                        request=original_exception.request
-                    )
-                if original_exception.status_code == 422:
-                    exception_mapping_worked = True
-                    raise BadRequestError(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        model=model,
-                        llm_provider="custom_openai",
-                        response=original_exception.response
-                    )
-                elif original_exception.status_code == 429:
-                    exception_mapping_worked = True
-                    raise RateLimitError(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        model=model,
-                        llm_provider="custom_openai",
-                        response=original_exception.response
-                    )
-                else:
-                    exception_mapping_worked = True
-                    raise APIError(
-                        status_code=original_exception.status_code,
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        llm_provider="custom_openai",
-                        model=model,
-                        request=original_exception.request
-                    )
        if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception):  # deal with edge-case invalid request error bug in openai-python sdk
            exception_mapping_worked = True
            raise BadRequestError(