Mirror of https://github.com/BerriAI/litellm.git
Merge pull request #3962 from BerriAI/litellm_return_num_rets_max_exceptions
[Feat] return `num_retries` and `max_retries` in exceptions
Commit fb49d036fb, 4 changed files with 355 additions and 9 deletions
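With this change, the exception that the Router re-raises after exhausting its retries carries the retry counts. A minimal sketch of how a caller might read them; the model settings and message below are placeholders and not from the PR, while `num_retries`, `max_retries`, and `litellm.LITELLM_EXCEPTION_TYPES` are what the diff adds:

    import litellm
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        num_retries=2,
    )

    try:
        router.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hello"}],
        )
    except Exception as e:
        if type(e) in litellm.LITELLM_EXCEPTION_TYPES:
            # both attributes may be None if the router never retried
            print("retries attempted:", e.num_retries, "max retries:", e.max_retries)
        raise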
@@ -803,6 +803,7 @@ from .exceptions import (
    APIConnectionError,
    APIResponseValidationError,
    UnprocessableEntityError,
    LITELLM_EXCEPTION_TYPES,
)
from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server
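The `LITELLM_EXCEPTION_TYPES` list is now re-exported at the package level, so callers can do membership checks without importing the exceptions module directly. A small illustrative check (assumed usage, not part of the diff):

    import litellm
    from litellm.exceptions import RateLimitError, ContextWindowExceededError

    # every mapped exception class is listed, including subclasses such as
    # ContextWindowExceededError (a subclass of BadRequestError)
    assert RateLimitError in litellm.LITELLM_EXCEPTION_TYPES
    assert ContextWindowExceededError in litellm.LITELLM_EXCEPTION_TYPES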
@@ -22,16 +22,36 @@ class AuthenticationError(openai.AuthenticationError):  # type: ignore
        model,
        response: httpx.Response,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.status_code = 401
        self.message = message
        self.llm_provider = llm_provider
        self.model = model
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        super().__init__(
            self.message, response=response, body=None
        )  # Call the base class constructor with the parameters it needs

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


# raise when invalid models passed, example gpt-8
class NotFoundError(openai.NotFoundError):  # type: ignore
@@ -42,16 +62,36 @@ class NotFoundError(openai.NotFoundError):  # type: ignore
        llm_provider,
        response: httpx.Response,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.status_code = 404
        self.message = message
        self.model = model
        self.llm_provider = llm_provider
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        super().__init__(
            self.message, response=response, body=None
        )  # Call the base class constructor with the parameters it needs

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


class BadRequestError(openai.BadRequestError):  # type: ignore
    def __init__(
@@ -61,6 +101,8 @@ class BadRequestError(openai.BadRequestError):  # type: ignore
        llm_provider,
        response: Optional[httpx.Response] = None,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.status_code = 400
        self.message = message
@@ -73,10 +115,28 @@ class BadRequestError(openai.BadRequestError):  # type: ignore
                method="GET", url="https://litellm.ai"
            ),  # mock request object
        )
        self.max_retries = max_retries
        self.num_retries = num_retries
        super().__init__(
            self.message, response=response, body=None
        )  # Call the base class constructor with the parameters it needs

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


class UnprocessableEntityError(openai.UnprocessableEntityError):  # type: ignore
    def __init__(
@@ -86,20 +146,46 @@ class UnprocessableEntityError(openai.UnprocessableEntityError):  # type: ignore
        llm_provider,
        response: httpx.Response,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.status_code = 422
        self.message = message
        self.model = model
        self.llm_provider = llm_provider
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        super().__init__(
            self.message, response=response, body=None
        )  # Call the base class constructor with the parameters it needs

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


class Timeout(openai.APITimeoutError):  # type: ignore
    def __init__(
        self, message, model, llm_provider, litellm_debug_info: Optional[str] = None  # old one-line signature, replaced by the expanded parameter list below
        self,
        message,
        model,
        llm_provider,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        super().__init__(
@@ -110,10 +196,25 @@ class Timeout(openai.APITimeoutError):  # type: ignore
        self.model = model
        self.llm_provider = llm_provider
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries

    # custom function to convert to str
    def __str__(self):
        return str(self.message)  # old body, replaced by the retry-aware version below
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


class PermissionDeniedError(openai.PermissionDeniedError):  # type:ignore
@@ -124,16 +225,36 @@ class PermissionDeniedError(openai.PermissionDeniedError):  # type:ignore
        model,
        response: httpx.Response,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.status_code = 403
        self.message = message
        self.llm_provider = llm_provider
        self.model = model
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        super().__init__(
            self.message, response=response, body=None
        )  # Call the base class constructor with the parameters it needs

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


class RateLimitError(openai.RateLimitError):  # type: ignore
    def __init__(
@@ -143,16 +264,36 @@ class RateLimitError(openai.RateLimitError):  # type: ignore
        model,
        response: httpx.Response,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.status_code = 429
        self.message = message
        self.llm_provider = llm_provider
        self.modle = model  # old line with the "modle" typo, replaced by the corrected line below
        self.model = model
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        super().__init__(
            self.message, response=response, body=None
        )  # Call the base class constructor with the parameters it needs

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
class ContextWindowExceededError(BadRequestError):  # type: ignore
@@ -176,6 +317,22 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
            response=response,
        )  # Call the base class constructor with the parameters it needs

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


# sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
class RejectedRequestError(BadRequestError):  # type: ignore
@@ -202,6 +359,22 @@ class RejectedRequestError(BadRequestError):  # type: ignore
            response=response,
        )  # Call the base class constructor with the parameters it needs

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


class ContentPolicyViolationError(BadRequestError):  # type: ignore
    # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
@@ -225,6 +398,22 @@ class ContentPolicyViolationError(BadRequestError):  # type: ignore
            response=response,
        )  # Call the base class constructor with the parameters it needs

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


class ServiceUnavailableError(openai.APIStatusError):  # type: ignore
    def __init__(
@@ -234,16 +423,36 @@ class ServiceUnavailableError(openai.APIStatusError):  # type: ignore
        model,
        response: httpx.Response,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.status_code = 503
        self.message = message
        self.llm_provider = llm_provider
        self.model = model
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        super().__init__(
            self.message, response=response, body=None
        )  # Call the base class constructor with the parameters it needs

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(openai.APIError):  # type: ignore
@@ -255,14 +464,34 @@ class APIError(openai.APIError):  # type: ignore
        model,
        request: httpx.Request,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.llm_provider = llm_provider
        self.model = model
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        super().__init__(self.message, request=request, body=None)  # type: ignore

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


# raised if an invalid request (not get, delete, put, post) is made
class APIConnectionError(openai.APIConnectionError):  # type: ignore
@@ -273,19 +502,45 @@ class APIConnectionError(openai.APIConnectionError):  # type: ignore
        model,
        request: httpx.Request,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.message = message
        self.llm_provider = llm_provider
        self.model = model
        self.status_code = 500
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        super().__init__(message=self.message, request=request)

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


# raised if an invalid request (not get, delete, put, post) is made
class APIResponseValidationError(openai.APIResponseValidationError):  # type: ignore
    def __init__(
        self, message, llm_provider, model, litellm_debug_info: Optional[str] = None  # old one-line signature, replaced by the expanded parameter list below
        self,
        message,
        llm_provider,
        model,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.message = message
        self.llm_provider = llm_provider
@@ -293,8 +548,26 @@ class APIResponseValidationError(openai.APIResponseValidationError):  # type: ignore
        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        response = httpx.Response(status_code=500, request=request)
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        super().__init__(response=response, body=None, message=message)

    def __str__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message

    def __repr__(self):
        _message = self.message
        if self.num_retries:
            _message += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            _message += f", LiteLLM Max Retries: {self.max_retries}"
        return _message


class OpenAIError(openai.OpenAIError):  # type: ignore
    def __init__(self, original_exception):
@@ -309,6 +582,25 @@ class OpenAIError(openai.OpenAIError):  # type: ignore
        self.llm_provider = "openai"


LITELLM_EXCEPTION_TYPES = [
    AuthenticationError,
    NotFoundError,
    BadRequestError,
    UnprocessableEntityError,
    Timeout,
    PermissionDeniedError,
    RateLimitError,
    ContextWindowExceededError,
    RejectedRequestError,
    ContentPolicyViolationError,
    ServiceUnavailableError,
    APIError,
    APIConnectionError,
    APIResponseValidationError,
    OpenAIError,
]


class BudgetExceededError(Exception):
    def __init__(self, current_cost, max_budget):
        self.current_cost = current_cost
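To illustrate the new `__str__`/`__repr__` behavior, here is a rough sketch that builds one of these exceptions by hand; the message, provider, model, and mock response are made up, and the keyword names follow the signatures shown above:

    import httpx
    from litellm.exceptions import RateLimitError

    mock_response = httpx.Response(
        status_code=429,
        request=httpx.Request(method="POST", url="https://api.openai.com/v1/chat/completions"),
    )
    exc = RateLimitError(
        message="rate limited",
        llm_provider="openai",
        model="gpt-3.5-turbo",
        response=mock_response,
        num_retries=2,
        max_retries=2,
    )
    # expected: "rate limited LiteLLM Retried: 2 times, LiteLLM Max Retries: 2"
    print(str(exc))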
@@ -2059,6 +2059,8 @@ class Router:
            response = await original_function(*args, **kwargs)
            return response
        except Exception as e:
            num_retries = None
            current_attempt = None
            original_exception = e
            """
            Retry Logic
@@ -2128,11 +2130,10 @@ class Router:
                )
                await asyncio.sleep(_timeout)

            # removed: the old block that appended retry and cooldown details to the message string
            try:
                cooldown_deployments = await self._async_get_cooldown_deployments()
                original_exception.message += f"\nNumber Retries = {current_attempt + 1}, Max Retries={num_retries}\nCooldown Deployments={cooldown_deployments}"
            except:
                pass
            # added: attach the retry counts to the exception object instead
            if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
                original_exception.max_retries = num_retries
                original_exception.num_retries = current_attempt

            raise original_exception

    def should_retry_this_error(
@@ -2333,6 +2334,8 @@ class Router:
            response = original_function(*args, **kwargs)
            return response
        except Exception as e:
            num_retries = None
            current_attempt = None
            original_exception = e
            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
            _healthy_deployments = self._get_healthy_deployments(
@@ -2383,6 +2386,11 @@ class Router:
                    healthy_deployments=_healthy_deployments,
                )
                time.sleep(_timeout)

            if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
                original_exception.max_retries = num_retries
                original_exception.num_retries = current_attempt

            raise original_exception

    ### HELPER FUNCTIONS
@@ -186,6 +186,51 @@ async def test_router_retry_policy(error_type):
    assert customHandler.previous_models == 3


@pytest.mark.asyncio
@pytest.mark.skip(
    reason="This is a local only test, use this to confirm if retry policy works"
)
async def test_router_retry_policy_on_429_errprs():
    from litellm.router import RetryPolicy

    retry_policy = RetryPolicy(
        RateLimitErrorRetries=2,
    )
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",  # openai model name
                "litellm_params": {
                    "model": "vertex_ai/gemini-1.5-pro-001",
                },
            },
        ],
        retry_policy=retry_policy,
        # set_verbose=True,
        # debug_level="DEBUG",
        allowed_fails=10,
    )

    customHandler = MyCustomHandler()
    litellm.callbacks = [customHandler]
    try:
        # litellm.set_verbose = True
        _one_message = [{"role": "user", "content": "Hello good morning"}]

        messages = [_one_message] * 5
        print("messages: ", messages)
        responses = await router.abatch_completion(
            models=["gpt-3.5-turbo"],
            messages=messages,
        )
        print("responses: ", responses)
    except Exception as e:
        print("got an exception", e)
        pass
    asyncio.sleep(0.05)
    print("customHandler.previous_models: ", customHandler.previous_models)


@pytest.mark.parametrize("model_group", ["gpt-3.5-turbo", "bad-model"])
@pytest.mark.asyncio
async def test_dynamic_router_retry_policy(model_group):