diff --git a/litellm/__init__.py b/litellm/__init__.py index 16395b27f3..c5e407eb1f 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -219,6 +219,7 @@ max_end_user_budget: Optional[float] = None #### RELIABILITY #### request_timeout: Optional[float] = 6000 num_retries: Optional[int] = None # per model endpoint +default_fallbacks: Optional[List] = None fallbacks: Optional[List] = None context_window_fallbacks: Optional[List] = None allowed_fails: int = 0 diff --git a/litellm/router.py b/litellm/router.py index b4603c6d07..3c459592cd 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -263,11 +263,12 @@ class Router: self.retry_after = retry_after self.routing_strategy = routing_strategy self.fallbacks = fallbacks or litellm.fallbacks - if default_fallbacks is not None: + if default_fallbacks is not None or litellm.default_fallbacks is not None: + _fallbacks = default_fallbacks or litellm.default_fallbacks if self.fallbacks is not None: - self.fallbacks.append({"*": default_fallbacks}) + self.fallbacks.append({"*": _fallbacks}) else: - self.fallbacks = [{"*": default_fallbacks}] + self.fallbacks = [{"*": _fallbacks}] self.context_window_fallbacks = ( context_window_fallbacks or litellm.context_window_fallbacks ) diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py index 4ab97b274d..6e483b9fed 100644 --- a/litellm/tests/test_router_fallbacks.py +++ b/litellm/tests/test_router_fallbacks.py @@ -1010,13 +1010,16 @@ async def test_service_unavailable_fallbacks(sync_mode): @pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.parametrize("litellm_module_fallbacks", [True, False]) @pytest.mark.asyncio -async def test_default_model_fallbacks(sync_mode): +async def test_default_model_fallbacks(sync_mode, litellm_module_fallbacks): """ Related issue - https://github.com/BerriAI/litellm/issues/3623 If model misconfigured, setup a default model for generic fallback """ + if litellm_module_fallbacks: + litellm.default_fallbacks = ["my-good-model"] router = Router( model_list=[ { @@ -1034,7 +1037,9 @@ async def test_default_model_fallbacks(sync_mode): }, }, ], - default_fallbacks=["my-good-model"], + default_fallbacks=( + ["my-good-model"] if litellm_module_fallbacks == False else None + ), ) if sync_mode: