From ab0bc87427ab15504e720d99e5c9a41259629408 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 25 Nov 2023 14:58:07 -0800
Subject: [PATCH] fix(router.py): check if fallbacks is none

---
 litellm/router.py                       | 10 ++++++++--
 litellm/tests/test_acooldowns_router.py |  2 +-
 litellm/tests/test_router.py            |  2 +-
 litellm/utils.py                        |  2 +-
 4 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index 4afe1d9d69..b4a21ccb5d 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -281,7 +281,7 @@ class Router:
             original_exception = e
             try:
                 self.print_verbose(f"Trying to fallback b/w models")
-                if isinstance(e, litellm.ContextWindowExceededError):
+                if isinstance(e, litellm.ContextWindowExceededError) and self.context_window_fallbacks is not None:
                     fallback_model_group = None
                     for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
                         if list(item.keys())[0] == model_group:
@@ -302,6 +302,8 @@
                     except Exception as e:
                         pass
                 else:
+                    if self.fallbacks is None:
+                        raise original_exception
                     self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
                     for item in self.fallbacks:
                         if list(item.keys())[0] == model_group:
@@ -374,9 +376,10 @@
             self.print_verbose(f"An exception occurs {original_exception}")
             try:
                 self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}")
-                if isinstance(e, litellm.ContextWindowExceededError):
+                if isinstance(e, litellm.ContextWindowExceededError) and self.context_window_fallbacks is not None:
                     self.print_verbose(f"inside context window fallbacks: {self.context_window_fallbacks}")
                     fallback_model_group = None
+
                     for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
@@ -396,6 +399,9 @@
                     except Exception as e:
                         pass
                 else:
+                    if self.fallbacks is None:
+                        raise original_exception
+
                     self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
                     fallback_model_group = None
                     for item in self.fallbacks:
diff --git a/litellm/tests/test_acooldowns_router.py b/litellm/tests/test_acooldowns_router.py
index 3542fa9869..b720c0b93a 100644
--- a/litellm/tests/test_acooldowns_router.py
+++ b/litellm/tests/test_acooldowns_router.py
@@ -93,7 +93,7 @@ def test_multiple_deployments_parallel():
                 del futures[future]  # Remove the done future with exception
 
             print(f"Remaining futures: {len(futures)}")
-
+    router.reset()
     end_time = time.time()
     print(results)
     print(f"ELAPSED TIME: {end_time - start_time}")
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index cb9e37a61d..92380034af 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -116,7 +116,7 @@ def test_reading_key_from_model_list():
     except Exception as e:
         os.environ["AZURE_API_KEY"] = old_api_key
         print(f"FAILED TEST")
-        pytest.fail("Got unexpected exception on router!", e)
+        pytest.fail(f"Got unexpected exception on router! - {e}")
 
 # test_reading_key_from_model_list()
 
diff --git a/litellm/utils.py b/litellm/utils.py
index b00e5ed860..04b8c4c5fe 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1991,7 +1991,7 @@ def get_optional_params(  # use the openai defaults
             optional_params["temperature"] = temperature
         if max_tokens is not None:
             optional_params["max_tokens"] = max_tokens
-        if logit_bias != {}:
+        if logit_bias is not None:
             optional_params["logit_bias"] = logit_bias
         if top_p is not None:
             optional_params["p"] = top_p
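
For reviewers: a minimal sketch of the behavior this commit targets. It assumes the Router constructor accepts fallbacks and context_window_fallbacks keyword arguments matching the attributes used in the diff; the model name and API key below are placeholders. Before the fix, a failed call with no fallbacks configured would error inside the fallback logic while iterating over self.fallbacks (None); with this change, the original exception is re-raised instead.

    # Sketch only -- the Router kwargs and placeholder credentials here
    # are assumptions, not taken verbatim from the patch.
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "sk-placeholder",  # placeholder, not a real key
                },
            }
        ],
        fallbacks=None,                 # no generic fallbacks configured
        context_window_fallbacks=None,  # no context-window fallbacks either
    )

    try:
        router.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hello"}],
        )
    except Exception as e:
        # After this patch, the provider's original error surfaces here
        # rather than an error from iterating over self.fallbacks (None).
        print(f"original exception propagated: {e}")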