From fa713abfc39cc62ec5e6ffe2713b705b87ad0133 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 25 Nov 2023 18:46:45 -0800 Subject: [PATCH] fix(router.py): check for fallbacks in completion params for router --- litellm/router.py | 27 ++++++++------- litellm/tests/test_router_fallbacks.py | 46 +++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 1ec581a0b0..c17fcf8774 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -272,6 +272,8 @@ class Router: If it fails after num_retries, fall back to another model group """ model_group = kwargs.get("model") + fallbacks = kwargs.pop("fallbacks", self.fallbacks) + context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks) try: response = await self.async_function_with_retries(*args, **kwargs) self.print_verbose(f'Async Response: {response}') @@ -281,9 +283,9 @@ class Router: original_exception = e try: self.print_verbose(f"Trying to fallback b/w models") - if isinstance(e, litellm.ContextWindowExceededError) and self.context_window_fallbacks is not None: + if isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None: fallback_model_group = None - for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] + for item in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break @@ -301,9 +303,9 @@ class Router: return response except Exception as e: pass - elif self.fallbacks is not None: - self.print_verbose(f"inside model fallbacks: {self.fallbacks}") - for item in self.fallbacks: + elif fallbacks is not None: + self.print_verbose(f"inside model fallbacks: {fallbacks}") + for item in fallbacks: if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break @@ -365,7 +367,8 @@ class Router: If it fails after num_retries, fall back to another model group """ model_group = kwargs.get("model") - + fallbacks = kwargs.pop("fallbacks", self.fallbacks) + context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks) try: response = self.function_with_retries(*args, **kwargs) return response @@ -374,11 +377,11 @@ class Router: self.print_verbose(f"An exception occurs {original_exception}") try: self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}") - if isinstance(e, litellm.ContextWindowExceededError) and self.context_window_fallbacks is not None: - self.print_verbose(f"inside context window fallbacks: {self.context_window_fallbacks}") + if isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None: + self.print_verbose(f"inside context window fallbacks: {context_window_fallbacks}") fallback_model_group = None - for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] + for item in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break @@ -396,10 +399,10 @@ class Router: return response except Exception as e: pass - elif self.fallbacks is not None: - self.print_verbose(f"inside model fallbacks: {self.fallbacks}") + elif fallbacks is not None: + self.print_verbose(f"inside model fallbacks: {fallbacks}") fallback_model_group = None - for item in self.fallbacks: + for item in fallbacks: if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py index cdcc8cc2da..0501ea8a2f 100644 --- a/litellm/tests/test_router_fallbacks.py +++ b/litellm/tests/test_router_fallbacks.py @@ -82,7 +82,7 @@ def test_sync_fallbacks(): router.flush_cache() except Exception as e: print(e) -test_sync_fallbacks() +# test_sync_fallbacks() def test_async_fallbacks(): litellm.set_verbose = False @@ -101,7 +101,7 @@ def test_async_fallbacks(): asyncio.run(test_get_response()) -test_async_fallbacks() +# test_async_fallbacks() def test_sync_context_window_fallbacks(): try: @@ -110,8 +110,46 @@ def test_sync_context_window_fallbacks(): kwargs["messages"] = [{"role": "user", "content": sample_text}] response = router.completion(**kwargs) print(f"response: {response}") - router.flush_cache() + router.reset() except Exception as e: print(e) -# test_sync_context_window_fallbacks() \ No newline at end of file +# test_sync_context_window_fallbacks() + +def test_dynamic_fallbacks_sync(): + """ + Allow setting the fallback in the router.completion() call. + """ + try: + router = Router(model_list=model_list, set_verbose=True) + kwargs = {} + kwargs["model"] = "azure/gpt-3.5-turbo" + kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}] + kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}] + response = router.completion(**kwargs) + print(f"response: {response}") + router.reset() + except Exception as e: + pytest.fail(f"An exception occurred - {e}") + +# test_dynamic_fallbacks_sync() + +def test_dynamic_fallbacks_async(): + """ + Allow setting the fallback in the router.completion() call. + """ + async def test_get_response(): + try: + router = Router(model_list=model_list, set_verbose=True) + kwargs = {} + kwargs["model"] = "azure/gpt-3.5-turbo" + kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}] + kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}] + response = await router.acompletion(**kwargs) + print(f"response: {response}") + router.reset() + except Exception as e: + pytest.fail(f"An exception occurred - {e}") + asyncio.run(test_get_response()) + +# test_dynamic_fallbacks_async() \ No newline at end of file