fix(router.py): fix should_retry logic for authentication errors

This commit is contained in:
Krrish Dholakia 2024-06-03 13:12:00 -07:00
parent 872cd2d8a0
commit a019fd05e3
2 changed files with 13 additions and 8 deletions

View file

@@ -2154,7 +2154,6 @@ class Router:
- there are no fallbacks - there are no fallbacks
- there are no healthy deployments in the same model group - there are no healthy deployments in the same model group
""" """
_num_healthy_deployments = 0 _num_healthy_deployments = 0
if healthy_deployments is not None and isinstance(healthy_deployments, list): if healthy_deployments is not None and isinstance(healthy_deployments, list):
_num_healthy_deployments = len(healthy_deployments) _num_healthy_deployments = len(healthy_deployments)
@@ -2167,15 +2166,21 @@ class Router:
raise error raise error
# Error we should only retry if there are other deployments # Error we should only retry if there are other deployments
if isinstance(error, openai.RateLimitError) or isinstance( if isinstance(error, openai.RateLimitError):
error, openai.AuthenticationError
):
if ( if (
_num_healthy_deployments <= 0 _num_healthy_deployments <= 0 # if no healthy deployments
and regular_fallbacks is not None and regular_fallbacks is not None # and fallbacks available
and len(regular_fallbacks) > 0 and len(regular_fallbacks) > 0
): ):
raise error raise error # then raise the error
if isinstance(error, openai.AuthenticationError):
"""
- if other deployments available -> retry
- else -> raise error
"""
if _num_healthy_deployments <= 0: # if no healthy deployments
raise error # then raise error
return True return True

View file

@@ -60,7 +60,7 @@ Test sync + async
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("error_type", ["Authorization Error", "API Error"]) @pytest.mark.parametrize("error_type", ["API Error", "Authorization Error"])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_router_retries_errors(sync_mode, error_type): async def test_router_retries_errors(sync_mode, error_type):
""" """