From 078fe97053becfcbe118a9941d50c34d79a09ede Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Tue, 20 Aug 2024 12:50:20 -0700
Subject: [PATCH] fix fallbacks dont recurse on the same fallback

---
 litellm/router.py                             | 114 +++++-------------
 .../router_utils/fallback_event_handlers.py   |  49 +++++++-
 litellm/tests/test_router_fallbacks.py        |  46 +++++++
 3 files changed, 122 insertions(+), 87 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index 46fb934c63..0f34775207 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -62,6 +62,7 @@ from litellm.router_utils.cooldown_callbacks import router_cooldown_handler
 from litellm.router_utils.fallback_event_handlers import (
     log_failure_fallback_event,
     log_success_fallback_event,
+    run_async_fallback,
 )
 from litellm.router_utils.handle_error import send_llm_exception_alert
 from litellm.scheduler import FlowItem, Scheduler
@@ -2383,34 +2384,16 @@ class Router:
                         if fallback_model_group is None:
                             raise original_exception
 
-                        for mg in fallback_model_group:
-                            """
-                            Iterate through the model groups and try calling that deployment
-                            """
-                            try:
-                                kwargs["model"] = mg
-                                kwargs.setdefault("metadata", {}).update(
-                                    {"model_group": mg}
-                                )  # update model_group used, if fallbacks are done
-                                response = await self.async_function_with_retries(
-                                    *args, **kwargs
-                                )
-                                verbose_router_logger.info(
-                                    "Successful fallback b/w models."
-                                )
-                                # callback for successfull_fallback_event():
-                                await log_success_fallback_event(
-                                    original_model_group=original_model_group,
-                                    kwargs=kwargs,
-                                )
+                        response = await run_async_fallback(
+                            *args,
+                            litellm_router=self,
+                            fallback_model_group=fallback_model_group,
+                            original_model_group=original_model_group,
+                            original_exception=original_exception,
+                            **kwargs,
+                        )
+                        return response
+
-
-                                return response
-                            except Exception as e:
-                                await log_failure_fallback_event(
-                                    original_model_group=original_model_group,
-                                    kwargs=kwargs,
-                                )
-                                pass
                     else:
                         error_message = "model={}. context_window_fallbacks={}. fallbacks={}.\n\nSet 'context_window_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format(
                             model_group, context_window_fallbacks, fallbacks
@@ -2436,33 +2419,15 @@ class Router:
                         if fallback_model_group is None:
                             raise original_exception
 
-                        for mg in fallback_model_group:
-                            """
-                            Iterate through the model groups and try calling that deployment
-                            """
-                            try:
-                                kwargs["model"] = mg
-                                kwargs.setdefault("metadata", {}).update(
-                                    {"model_group": mg}
-                                )  # update model_group used, if fallbacks are done
-                                response = await self.async_function_with_retries(
-                                    *args, **kwargs
-                                )
-                                verbose_router_logger.info(
-                                    "Successful fallback b/w models."
-                                )
-                                # callback for successfull_fallback_event():
-                                await log_success_fallback_event(
-                                    original_model_group=original_model_group,
-                                    kwargs=kwargs,
-                                )
-                                return response
-                            except Exception as e:
-                                await log_failure_fallback_event(
-                                    original_model_group=original_model_group,
-                                    kwargs=kwargs,
-                                )
-                                pass
+                        response = await run_async_fallback(
+                            *args,
+                            litellm_router=self,
+                            fallback_model_group=fallback_model_group,
+                            original_model_group=original_model_group,
+                            original_exception=original_exception,
+                            **kwargs,
+                        )
+                        return response
                     else:
                         error_message = "model={}. content_policy_fallback={}. fallbacks={}.\n\nSet 'content_policy_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format(
                             model_group, content_policy_fallbacks, fallbacks
@@ -2502,39 +2467,16 @@ class Router:
                             if hasattr(original_exception, "message"):
                                 original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
                             raise original_exception
 
-                        for mg in fallback_model_group:
-                            """
-                            Iterate through the model groups and try calling that deployment
-                            """
-                            try:
-                                ## LOGGING
-                                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
-                                verbose_router_logger.info(
-                                    f"Falling back to model_group = {mg}"
-                                )
-                                kwargs["model"] = mg
-                                kwargs.setdefault("metadata", {}).update(
-                                    {"model_group": mg}
-                                )  # update model_group used, if fallbacks are done
-                                response = await self.async_function_with_fallbacks(
-                                    *args, **kwargs
-                                )
-                                verbose_router_logger.info(
-                                    "Successful fallback b/w models."
-                                )
-                                # callback for successfull_fallback_event():
-                                await log_success_fallback_event(
-                                    original_model_group=original_model_group,
-                                    kwargs=kwargs,
-                                )
-                                return response
-                            except Exception as e:
-                                await log_failure_fallback_event(
-                                    original_model_group=original_model_group,
-                                    kwargs=kwargs,
-                                )
-                                raise e
+                        response = await run_async_fallback(
+                            *args,
+                            litellm_router=self,
+                            fallback_model_group=fallback_model_group,
+                            original_model_group=original_model_group,
+                            original_exception=original_exception,
+                            **kwargs,
+                        )
+                        return response
             except Exception as new_exception:
                 verbose_router_logger.error(
                     "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
diff --git a/litellm/router_utils/fallback_event_handlers.py b/litellm/router_utils/fallback_event_handlers.py
index 98d9cd92de..02465a0148 100644
--- a/litellm/router_utils/fallback_event_handlers.py
+++ b/litellm/router_utils/fallback_event_handlers.py
@@ -1,9 +1,56 @@
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
 
 import litellm
 from litellm._logging import verbose_router_logger
 from litellm.integrations.custom_logger import CustomLogger
 
+if TYPE_CHECKING:
+    from litellm.router import Router as _Router
+
+    LitellmRouter = _Router
+else:
+    LitellmRouter = Any
+
+
+async def run_async_fallback(
+    litellm_router: LitellmRouter,
+    *args: Tuple[Any],
+    fallback_model_group: List[str],
+    original_model_group: str,
+    original_exception: Exception,
+    **kwargs,
+) -> Any:
+    """
+    Iterate through the model groups and try calling that deployment.
+    """
+    error_from_fallbacks = original_exception
+    for mg in fallback_model_group:
+        if mg == original_model_group:
+            continue
+        try:
+            # LOGGING
+            kwargs = litellm_router.log_retry(kwargs=kwargs, e=original_exception)
+            verbose_router_logger.info(f"Falling back to model_group = {mg}")
+            kwargs["model"] = mg
+            kwargs.setdefault("metadata", {}).update(
+                {"model_group": mg}
+            )  # update model_group used, if fallbacks are done
+            response = await litellm_router.async_function_with_fallbacks(
+                *args, **kwargs
+            )
+            verbose_router_logger.info("Successful fallback b/w models.")
+            # callback for successfull_fallback_event():
+            await log_success_fallback_event(
+                original_model_group=original_model_group, kwargs=kwargs
+            )
+            return response
+        except Exception as e:
+            error_from_fallbacks = e
+            await log_failure_fallback_event(
+                original_model_group=original_model_group, kwargs=kwargs
+            )
+    raise error_from_fallbacks
+
 
 async def log_success_fallback_event(original_model_group: str, kwargs: dict):
     for _callback in litellm.callbacks:
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index 2c552a64bf..dddae151ac 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -1185,3 +1185,49 @@ async def test_router_content_policy_fallbacks(
     )
 
     assert response.model == "my-fake-model"
+
+
+@pytest.mark.parametrize("sync_mode", [False])
+@pytest.mark.asyncio
+async def test_using_default_fallback(sync_mode):
+    """
+    Tests Client Side Fallbacks
+
+    User can pass "fallbacks": ["gpt-3.5-turbo"] and this should work
+
+    """
+    litellm.set_verbose = True
+
+    import logging
+
+    from litellm._logging import verbose_logger, verbose_router_logger
+
+    verbose_logger.setLevel(logging.DEBUG)
+    verbose_router_logger.setLevel(logging.DEBUG)
+    litellm.default_fallbacks = ["very-bad-model"]
+    router = Router(
+        model_list=[
+            {
+                "model_name": "openai/*",
+                "litellm_params": {
+                    "model": "openai/*",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ],
+    )
+    try:
+        if sync_mode:
+            response = router.completion(
+                model="openai/foo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+        else:
+            response = await router.acompletion(
+                model="openai/foo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+        print("got response=", response)
+        pytest.fail(f"Expected call to fail we passed model=openai/foo")
+    except Exception as e:
+        print("got exception = ", e)