test - client side fallbacks

Ishaan Jaff 2024-06-10 15:00:36 -07:00
parent 1bcdee8c99
commit 94210a86b4
2 changed files with 61 additions and 8 deletions


@@ -2058,7 +2058,7 @@ class Router:
             ## check for specific model group-specific fallbacks
             if isinstance(fallbacks, list):
                 fallback_model_group = fallbacks
-            else:
+            elif isinstance(fallbacks, dict):
                 for idx, item in enumerate(fallbacks):
                     if list(item.keys())[0] == model_group:
                         fallback_model_group = item[model_group]
@@ -2313,6 +2313,9 @@ class Router:
             verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
             fallback_model_group = None
             generic_fallback_idx: Optional[int] = None
+            if isinstance(fallbacks, list):
+                fallback_model_group = fallbacks
+            elif isinstance(fallbacks, dict):
                 ## check for specific model group-specific fallbacks
                 for idx, item in enumerate(fallbacks):
                     if list(item.keys())[0] == model_group:
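
Taken together, the two Router hunks let `fallbacks` arrive in either of two shapes: a plain list of model-group names supplied by the client on the request itself, used for any failing group, or the existing config-style list of one-key dicts mapping a model group to its fallbacks. A minimal sketch of that selection logic, using a hypothetical `resolve_fallbacks` helper (not a litellm API) and simplified from the Router internals above (it discriminates the two shapes by element type rather than by the exact isinstance checks in the diff):

from typing import Dict, List, Optional, Union


def resolve_fallbacks(
    fallbacks: Union[List[str], List[Dict[str, List[str]]]],
    model_group: str,
) -> Optional[List[str]]:
    # Flat list of model-group names -> client-side fallbacks, used as-is.
    if all(isinstance(f, str) for f in fallbacks):
        return list(fallbacks)
    # Otherwise config-style: [{"model-group": ["fallback-1", ...]}, ...];
    # pick the entry whose key matches the failing model group.
    for item in fallbacks:
        if list(item.keys())[0] == model_group:
            return item[model_group]
    return None


assert resolve_fallbacks(["my-good-model"], "bad-model") == ["my-good-model"]
assert resolve_fallbacks([{"bad-model": ["gpt-4o"]}], "bad-model") == ["gpt-4o"]
assert resolve_fallbacks([{"other-model": ["gpt-4o"]}], "bad-model") is None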


@@ -1059,3 +1059,53 @@ async def test_default_model_fallbacks(sync_mode, litellm_module_fallbacks):
     assert isinstance(response, litellm.ModelResponse)
     assert response.model is not None and response.model == "gpt-4o"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_client_side_fallbacks_list(sync_mode):
+    """
+    Tests Client Side Fallbacks
+
+    User can pass "fallbacks": ["gpt-3.5-turbo"] and this should work
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "bad-model",
+                "litellm_params": {
+                    "model": "openai/my-bad-model",
+                    "api_key": "my-bad-api-key",
+                },
+            },
+            {
+                "model_name": "my-good-model",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ],
+    )
+
+    if sync_mode:
+        response = router.completion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            fallbacks=["my-good-model"],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+    else:
+        response = await router.acompletion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            fallbacks=["my-good-model"],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+
+    assert isinstance(response, litellm.ModelResponse)
+    assert response.model is not None and response.model == "gpt-4o"
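
Outside the mocked test, the same per-request fallbacks would be exercised as below; a usage sketch, not part of the commit, reusing the placeholder model names and keys from the test above and dropping `mock_testing_fallbacks`/`mock_response` so the fallback is driven by a real failure:

import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "bad-model",
            "litellm_params": {
                "model": "openai/my-bad-model",
                "api_key": "my-bad-api-key",
            },
        },
        {
            "model_name": "my-good-model",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
    ],
)

# "bad-model" fails (invalid key), so the router retries on "my-good-model".
response = router.completion(
    model="bad-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    fallbacks=["my-good-model"],  # client-side fallbacks: plain list of model groups
)
print(response.model)  # expected: "gpt-4o"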