diff --git a/litellm/router.py b/litellm/router.py
index c4ea521c5c..adf8f4897e 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -2058,7 +2058,7 @@ class Router:
                 ## check for specific model group-specific fallbacks
                 if isinstance(fallbacks, list):
                     fallback_model_group = fallbacks
-                else:
+                elif isinstance(fallbacks, dict):
                     for idx, item in enumerate(fallbacks):
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
@@ -2313,13 +2313,16 @@ class Router:
                 verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
                 fallback_model_group = None
                 generic_fallback_idx: Optional[int] = None
-                ## check for specific model group-specific fallbacks
-                for idx, item in enumerate(fallbacks):
-                    if list(item.keys())[0] == model_group:
-                        fallback_model_group = item[model_group]
-                        break
-                    elif list(item.keys())[0] == "*":
-                        generic_fallback_idx = idx
+                if isinstance(fallbacks, list):
+                    fallback_model_group = fallbacks
+                elif isinstance(fallbacks, dict):
+                    ## check for specific model group-specific fallbacks
+                    for idx, item in enumerate(fallbacks):
+                        if list(item.keys())[0] == model_group:
+                            fallback_model_group = item[model_group]
+                            break
+                        elif list(item.keys())[0] == "*":
+                            generic_fallback_idx = idx
                 ## if none, check for generic fallback
                 if (
                     fallback_model_group is None
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index 6e483b9fed..c6e0e54111 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -1059,3 +1059,53 @@ async def test_default_model_fallbacks(sync_mode, litellm_module_fallbacks):
 
     assert isinstance(response, litellm.ModelResponse)
     assert response.model is not None and response.model == "gpt-4o"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_client_side_fallbacks_list(sync_mode):
+    """
+
+    Tests Client Side Fallbacks
+
+    User can pass "fallbacks": ["gpt-3.5-turbo"] and this should work
+
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "bad-model",
+                "litellm_params": {
+                    "model": "openai/my-bad-model",
+                    "api_key": "my-bad-api-key",
+                },
+            },
+            {
+                "model_name": "my-good-model",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ],
+    )
+
+    if sync_mode:
+        response = router.completion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            fallbacks=["my-good-model"],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+    else:
+        response = await router.acompletion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            fallbacks=["my-good-model"],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+
+    assert isinstance(response, litellm.ModelResponse)
+    assert response.model is not None and response.model == "gpt-4o"
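
For reviewers: the behavior this patch adds can be exercised outside pytest with the minimal sketch below. It mirrors `test_client_side_fallbacks_list`, passing `fallbacks` as a plain list of model group names at request time instead of the list-of-dicts form (e.g. `[{"bad-model": ["my-good-model"]}]`) that was previously required; `mock_testing_fallbacks` and `mock_response` are taken from the test so the snippet runs without a live provider call. The deployment names are placeholders, not real models.

```python
import os

from litellm import Router

router = Router(
    model_list=[
        {
            # Placeholder deployment that will fail, triggering the fallback path.
            "model_name": "bad-model",
            "litellm_params": {
                "model": "openai/my-bad-model",
                "api_key": "my-bad-api-key",
            },
        },
        {
            # Healthy deployment the router should fall back to.
            "model_name": "my-good-model",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
    ],
)

# Client-side fallbacks as a plain list: with this patch, the list is used
# directly as the fallback model group (see the isinstance checks above),
# rather than being scanned for {model_group: [...]} dict entries.
response = router.completion(
    model="bad-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    fallbacks=["my-good-model"],
    mock_testing_fallbacks=True,  # force the failure path, as in the test
    mock_response="Hey! nice day",
)
print(response.model)  # expected: "gpt-4o"
```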