fix(router.py): fix default retry logic

This commit is contained in:
Krrish Dholakia 2024-04-25 11:57:27 -07:00
parent f84e0f4a24
commit f1b2405fe0
6 changed files with 63 additions and 57 deletions

View file

@ -1,7 +1,7 @@
#### What this tests ####
# This tests litellm router
import sys, os, time
import sys, os, time, openai
import traceback, asyncio
import pytest
@ -18,6 +18,45 @@ from dotenv import load_dotenv
load_dotenv()
@pytest.mark.parametrize("num_retries", [None, 2])
@pytest.mark.parametrize("max_retries", [None, 4])
def test_router_num_retries_init(num_retries, max_retries):
    """
    Check how retry settings are resolved when a Router is initialised.

    - router level: `num_retries` is kept when passed, otherwise
      `router.num_retries` equals `openai.DEFAULT_MAX_RETRIES`
    - client level: the deployment's async client carries the `max_retries`
      given in litellm_params when passed, otherwise 0
    """
    deployment_id = 12345

    # params forwarded to the litellm completion/embedding call
    litellm_params = {
        "model": "azure/chatgpt-v-2",
        "api_key": "bad-key",
        "api_version": os.getenv("AZURE_API_VERSION"),
        "api_base": os.getenv("AZURE_API_BASE"),
        "max_retries": max_retries,
    }
    deployment = {
        "model_name": "gpt-3.5-turbo",  # openai-style alias for the deployment
        "litellm_params": litellm_params,
        "model_info": {"id": deployment_id},
    }

    router = Router(model_list=[deployment], num_retries=num_retries)

    # Router-level retries: explicit value wins, else the openai default.
    expected_num_retries = (
        num_retries if num_retries is not None else openai.DEFAULT_MAX_RETRIES
    )
    assert router.num_retries == expected_num_retries

    # Look up the async client the router holds for this deployment id.
    model_client = router._get_client(
        {"model_info": {"id": deployment_id}}, client_type="async", kwargs={}
    )

    # Client-level retries: explicit value wins, else 0.
    expected_max_retries = max_retries if max_retries is not None else 0
    assert model_client.max_retries == expected_max_retries
def test_exception_raising():
# this tests if the router raises an exception when invalid params are set
# in this test both deployments have bad keys - Keep this test. It validates if the router raises the most recent exception