This commit is contained in:
swiftdevil 2025-04-24 00:57:58 -07:00 committed by GitHub
commit 44ad713078
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 2 deletions

View file

@ -3239,13 +3239,13 @@ async def acompletion_with_retries(*args, **kwargs):
retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(
retryer = tenacity.AsyncRetrying(
wait=tenacity.wait_exponential(multiplier=1, max=10),
stop=tenacity.stop_after_attempt(num_retries),
reraise=True,
)
else:
retryer = tenacity.Retrying(
retryer = tenacity.AsyncRetrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
return await retryer(original_function, *args, **kwargs)

View file

@ -146,3 +146,27 @@ async def test_completion_with_retries(sync_mode):
mock_completion.assert_called_once()
assert mock_completion.call_args.kwargs["num_retries"] == 0
assert mock_completion.call_args.kwargs["max_retries"] == 0
@pytest.mark.asyncio
async def test_acompletion_with_retries_retries():
    """
    Check that acompletion_with_retries keeps re-invoking a failing function.

    The supplied ``original_function`` awaits the patched
    ``litellm.acompletion`` mock and then always raises ``litellm.Timeout``.
    With ``num_retries=3`` the mock is expected to have been called three
    times by the time the Timeout reaches the caller.
    """
    from unittest.mock import patch

    expected_attempts = 3
    request_kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"gm": "vibe", "role": "user"}],
        "num_retries": expected_attempts,
    }

    with patch.object(litellm, "acompletion") as acompletion_mock:

        async def always_times_out(*call_args, **call_kwargs):
            # Record the attempt on the mock first, then fail unconditionally
            # so that every attempt ends in a Timeout.
            await acompletion_mock(*call_args, **call_kwargs)
            raise litellm.Timeout(
                message='test', model='gpt-3.5-turbo', llm_provider='mock'
            )

        with pytest.raises(litellm.Timeout):
            await acompletion_with_retries(
                **request_kwargs,
                original_function=always_times_out,
            )

        # One recorded call per attempt made by the retry wrapper.
        assert acompletion_mock.call_count == expected_attempts