fix(router.py): fix sync should_retry logic

This commit is contained in:
Krrish Dholakia 2024-04-27 14:48:07 -07:00
parent f31d3c4e9f
commit 5e0bd5982e
2 changed files with 103 additions and 53 deletions

View file

@ -1453,6 +1453,7 @@ class Router:
await asyncio.sleep(timeout)
elif RouterErrors.user_defined_ratelimit_error.value in str(e):
raise e # don't wait to retry if deployment hits user-defined rate-limit
elif hasattr(original_exception, "status_code") and litellm._should_retry(
status_code=original_exception.status_code
):
@ -1614,6 +1615,38 @@ class Router:
raise e
raise original_exception
def _router_should_retry(
self, e: Exception, remaining_retries: int, num_retries: int
):
if "No models available" in str(e):
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
time.sleep(timeout)
elif (
hasattr(e, "status_code")
and hasattr(e, "response")
and litellm._should_retry(status_code=e.status_code)
):
if hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
response_headers=e.response.headers,
min_timeout=self.retry_after,
)
else:
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
time.sleep(timeout)
else:
raise e
def function_with_retries(self, *args, **kwargs):
"""
Try calling the model 3 times. Shuffle between available deployments.
@ -1649,6 +1682,11 @@ class Router:
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
### RETRY
self._router_should_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
)
for current_attempt in range(num_retries):
verbose_router_logger.debug(
f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
@ -1662,34 +1700,11 @@ class Router:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt
if "No models available" in str(e):
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
time.sleep(timeout)
elif (
hasattr(e, "status_code")
and hasattr(e, "response")
and litellm._should_retry(status_code=e.status_code)
):
if hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
response_headers=e.response.headers,
min_timeout=self.retry_after,
)
else:
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
time.sleep(timeout)
else:
raise e
self._router_should_retry(
e=e,
remaining_retries=remaining_retries,
num_retries=num_retries,
)
raise original_exception
### HELPER FUNCTIONS

View file

@ -396,7 +396,9 @@ def test_router_init_gpt_4_vision_enhancements():
pytest.fail(f"Error occurred: {e}")
def test_openai_with_organization():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_openai_with_organization(sync_mode):
try:
print("Testing OpenAI with organization")
model_list = [
@ -418,32 +420,65 @@ def test_openai_with_organization():
print(router.model_list)
print(router.model_list[0])
openai_client = router._get_client(
deployment=router.model_list[0],
kwargs={"input": ["hello"], "model": "openai-bad-org"},
)
print(vars(openai_client))
assert openai_client.organization == "org-ikDc4ex8NB"
# bad org raises error
try:
response = router.completion(
model="openai-bad-org",
messages=[{"role": "user", "content": "this is a test"}],
if sync_mode:
openai_client = router._get_client(
deployment=router.model_list[0],
kwargs={"input": ["hello"], "model": "openai-bad-org"},
)
pytest.fail("Request should have failed - This organization does not exist")
except Exception as e:
print("Got exception: " + str(e))
assert "No such organization: org-ikDc4ex8NB" in str(e)
print(vars(openai_client))
# good org works
response = router.completion(
model="openai-good-org",
messages=[{"role": "user", "content": "this is a test"}],
max_tokens=5,
)
assert openai_client.organization == "org-ikDc4ex8NB"
# bad org raises error
try:
response = router.completion(
model="openai-bad-org",
messages=[{"role": "user", "content": "this is a test"}],
)
pytest.fail(
"Request should have failed - This organization does not exist"
)
except Exception as e:
print("Got exception: " + str(e))
assert "No such organization: org-ikDc4ex8NB" in str(e)
# good org works
response = router.completion(
model="openai-good-org",
messages=[{"role": "user", "content": "this is a test"}],
max_tokens=5,
)
else:
openai_client = router._get_client(
deployment=router.model_list[0],
kwargs={"input": ["hello"], "model": "openai-bad-org"},
client_type="async",
)
print(vars(openai_client))
assert openai_client.organization == "org-ikDc4ex8NB"
# bad org raises error
try:
response = await router.acompletion(
model="openai-bad-org",
messages=[{"role": "user", "content": "this is a test"}],
)
pytest.fail(
"Request should have failed - This organization does not exist"
)
except Exception as e:
print("Got exception: " + str(e))
assert "No such organization: org-ikDc4ex8NB" in str(e)
# good org works
response = await router.acompletion(
model="openai-good-org",
messages=[{"role": "user", "content": "this is a test"}],
max_tokens=5,
)
except Exception as e:
pytest.fail(f"Error occurred: {e}")