From 87ff26ff2798a0e70673227aba6d34810ebef7b8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 30 Apr 2024 15:23:19 -0700 Subject: [PATCH 1/6] fix(router.py): unify retry timeout logic across sync + async function_with_retries --- litellm/router.py | 148 ++++++++++++++++++++--------------- litellm/tests/test_router.py | 36 +++++++++ 2 files changed, 119 insertions(+), 65 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 8ea1a124a..f173d52fb 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1418,6 +1418,13 @@ class Router: traceback.print_exc() raise original_exception + async def _async_router_should_retry( + self, e: Exception, remaining_retries: int, num_retries: int + ): + """ + Calculate back-off, then retry + """ + async def async_function_with_retries(self, *args, **kwargs): verbose_router_logger.debug( f"Inside async function with retries: args - {args}; kwargs - {kwargs}" @@ -1450,40 +1457,47 @@ class Router: raise original_exception ### RETRY #### check if it should retry + back-off if required - if "No models available" in str( - e - ) or RouterErrors.no_deployments_available.value in str(e): - timeout = litellm._calculate_retry_after( - remaining_retries=num_retries, - max_retries=num_retries, - min_timeout=self.retry_after, - ) - await asyncio.sleep(timeout) - elif RouterErrors.user_defined_ratelimit_error.value in str(e): - raise e # don't wait to retry if deployment hits user-defined rate-limit + # if "No models available" in str( + # e + # ) or RouterErrors.no_deployments_available.value in str(e): + # timeout = litellm._calculate_retry_after( + # remaining_retries=num_retries, + # max_retries=num_retries, + # min_timeout=self.retry_after, + # ) + # await asyncio.sleep(timeout) + # elif RouterErrors.user_defined_ratelimit_error.value in str(e): + # raise e # don't wait to retry if deployment hits user-defined rate-limit - elif hasattr(original_exception, "status_code") and litellm._should_retry( - status_code=original_exception.status_code - ): - if hasattr(original_exception, "response") and hasattr( - original_exception.response, "headers" - ): - timeout = litellm._calculate_retry_after( - remaining_retries=num_retries, - max_retries=num_retries, - response_headers=original_exception.response.headers, - min_timeout=self.retry_after, - ) - else: - timeout = litellm._calculate_retry_after( - remaining_retries=num_retries, - max_retries=num_retries, - min_timeout=self.retry_after, - ) - await asyncio.sleep(timeout) - else: - raise original_exception + # elif hasattr(original_exception, "status_code") and litellm._should_retry( + # status_code=original_exception.status_code + # ): + # if hasattr(original_exception, "response") and hasattr( + # original_exception.response, "headers" + # ): + # timeout = litellm._calculate_retry_after( + # remaining_retries=num_retries, + # max_retries=num_retries, + # response_headers=original_exception.response.headers, + # min_timeout=self.retry_after, + # ) + # else: + # timeout = litellm._calculate_retry_after( + # remaining_retries=num_retries, + # max_retries=num_retries, + # min_timeout=self.retry_after, + # ) + # await asyncio.sleep(timeout) + # else: + # raise original_exception + ### RETRY + _timeout = self._router_should_retry( + e=original_exception, + remaining_retries=num_retries, + num_retries=num_retries, + ) + await asyncio.sleep(_timeout) ## LOGGING if num_retries > 0: kwargs = self.log_retry(kwargs=kwargs, e=original_exception) @@ -1505,34 +1519,37 @@ class Router: ## LOGGING kwargs = 
self.log_retry(kwargs=kwargs, e=e) remaining_retries = num_retries - current_attempt - if "No models available" in str(e): - timeout = litellm._calculate_retry_after( - remaining_retries=remaining_retries, - max_retries=num_retries, - min_timeout=self.retry_after, - ) - await asyncio.sleep(timeout) - elif ( - hasattr(e, "status_code") - and hasattr(e, "response") - and litellm._should_retry(status_code=e.status_code) - ): - if hasattr(e.response, "headers"): - timeout = litellm._calculate_retry_after( - remaining_retries=remaining_retries, - max_retries=num_retries, - response_headers=e.response.headers, - min_timeout=self.retry_after, - ) - else: - timeout = litellm._calculate_retry_after( - remaining_retries=remaining_retries, - max_retries=num_retries, - min_timeout=self.retry_after, - ) - await asyncio.sleep(timeout) - else: - raise e + # if "No models available" in str(e): + # timeout = litellm._calculate_retry_after( + # remaining_retries=remaining_retries, + # max_retries=num_retries, + # min_timeout=self.retry_after, + # ) + # await asyncio.sleep(timeout) + # elif ( + # hasattr(e, "status_code") + # and hasattr(e, "response") + # and litellm._should_retry(status_code=e.status_code) + # ): + # if hasattr(e.response, "headers"): + # timeout = litellm._calculate_retry_after( + # remaining_retries=remaining_retries, + # max_retries=num_retries, + # response_headers=e.response.headers, + # min_timeout=self.retry_after, + # ) + # else: + # timeout = litellm._calculate_retry_after( + # remaining_retries=remaining_retries, + # max_retries=num_retries, + # min_timeout=self.retry_after, + # ) + _timeout = self._router_should_retry( + e=original_exception, + remaining_retries=remaining_retries, + num_retries=num_retries, + ) + await asyncio.sleep(_timeout) raise original_exception def function_with_fallbacks(self, *args, **kwargs): @@ -1625,7 +1642,7 @@ class Router: def _router_should_retry( self, e: Exception, remaining_retries: int, num_retries: int - ): + ) -> int | float: """ Calculate back-off, then retry """ @@ -1636,14 +1653,13 @@ class Router: response_headers=e.response.headers, min_timeout=self.retry_after, ) - time.sleep(timeout) else: timeout = litellm._calculate_retry_after( remaining_retries=remaining_retries, max_retries=num_retries, min_timeout=self.retry_after, ) - time.sleep(timeout) + return timeout def function_with_retries(self, *args, **kwargs): """ @@ -1677,11 +1693,12 @@ class Router: if num_retries > 0: kwargs = self.log_retry(kwargs=kwargs, e=original_exception) ### RETRY - self._router_should_retry( + _timeout = self._router_should_retry( e=original_exception, remaining_retries=num_retries, num_retries=num_retries, ) + time.sleep(_timeout) for current_attempt in range(num_retries): verbose_router_logger.debug( f"retrying request. 
Current attempt - {current_attempt}; retries left: {num_retries}" @@ -1695,11 +1712,12 @@ class Router: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=e) remaining_retries = num_retries - current_attempt - self._router_should_retry( + _timeout = self._router_should_retry( e=e, remaining_retries=remaining_retries, num_retries=num_retries, ) + time.sleep(_timeout) raise original_exception ### HELPER FUNCTIONS diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 7520ac75f..8c6b9fa01 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -104,6 +104,42 @@ def test_router_timeout_init(timeout, ssl_verify): ) +@pytest.mark.parametrize("sync_mode", [False, True]) +@pytest.mark.asyncio +async def test_router_retries(sync_mode): + """ + - make sure retries work as expected + """ + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "bad-key"}, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + }, + }, + ] + + router = Router(model_list=model_list, num_retries=2) + + if sync_mode: + router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + else: + await router.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + + @pytest.mark.parametrize( "mistral_api_base", [ From 668a5353eeb86d3d09530d0219e4200406f251ef Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 30 Apr 2024 15:35:16 -0700 Subject: [PATCH 2/6] fix(router.py): fix linting issue --- litellm/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index f173d52fb..4f1d6b96e 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1642,7 +1642,7 @@ class Router: def _router_should_retry( self, e: Exception, remaining_retries: int, num_retries: int - ) -> int | float: + ) -> Union[int, float]: """ Calculate back-off, then retry """ From 8ee51a96f47113befbf349e57d5a43e4f10f438f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 30 Apr 2024 16:42:10 -0700 Subject: [PATCH 3/6] test(test_router_debug_logs.py): fix retry logic --- litellm/tests/test_router_debug_logs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/tests/test_router_debug_logs.py b/litellm/tests/test_router_debug_logs.py index 0bc711b15..19a83287f 100644 --- a/litellm/tests/test_router_debug_logs.py +++ b/litellm/tests/test_router_debug_logs.py @@ -46,6 +46,7 @@ def test_async_fallbacks(caplog): router = Router( model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["azure/gpt-3.5-turbo"]}], + num_retries=1, ) user_message = "Hello, how are you?" @@ -83,6 +84,7 @@ def test_async_fallbacks(caplog): expected_logs = [ "Intialized router with Routing strategy: simple-shuffle\n\nRouting fallbacks: [{'gpt-3.5-turbo': ['azure/gpt-3.5-turbo']}]\n\nRouting context window fallbacks: None\n\nRouter Redis Caching=None", "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. 
You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m", + "litellm.acompletion(model=None)\x1b[31m Exception No deployments available for selected model, passed model=gpt-3.5-turbo\x1b[0m", "Falling back to model_group = azure/gpt-3.5-turbo", "litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m", ] From 1baad80c7d5338dc4cdb6794811d6728d97f922e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 30 Apr 2024 17:54:00 -0700 Subject: [PATCH 4/6] fix(router.py): cooldown deployments, for 401 errors --- litellm/main.py | 13 +++ litellm/router.py | 34 +++++-- litellm/tests/conftest.py | 1 + litellm/tests/test_router.py | 2 + litellm/tests/test_router_fallbacks.py | 8 +- litellm/tests/test_router_retries.py | 121 +++++++++++++++++++++++++ 6 files changed, 165 insertions(+), 14 deletions(-) create mode 100644 litellm/tests/test_router_retries.py diff --git a/litellm/main.py b/litellm/main.py index 454f7f716..569418eca 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -387,6 +387,19 @@ def mock_completion( - If 'stream' is True, it returns a response that mimics the behavior of a streaming completion. """ try: + ## LOGGING + logging.pre_call( + input=messages, + api_key="mock-key", + ) + if isinstance(mock_response, Exception): + raise litellm.APIError( + status_code=500, # type: ignore + message=str(mock_response), + llm_provider="openai", # type: ignore + model=model, # type: ignore + request=httpx.Request(method="POST", url="https://api.openai.com/v1/"), + ) model_response = ModelResponse(stream=stream) if stream is True: # don't try to access stream object, diff --git a/litellm/router.py b/litellm/router.py index 4f1d6b96e..bcf2d2cb6 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1418,13 +1418,6 @@ class Router: traceback.print_exc() raise original_exception - async def _async_router_should_retry( - self, e: Exception, remaining_retries: int, num_retries: int - ): - """ - Calculate back-off, then retry - """ - async def async_function_with_retries(self, *args, **kwargs): verbose_router_logger.debug( f"Inside async function with retries: args - {args}; kwargs - {kwargs}" @@ -1674,6 +1667,7 @@ class Router: context_window_fallbacks = kwargs.pop( "context_window_fallbacks", self.context_window_fallbacks ) + try: # if the function call is successful, no exception will be raised and we'll break out of the loop response = original_function(*args, **kwargs) @@ -1751,10 +1745,11 @@ class Router: ) # i.e. azure metadata = kwargs.get("litellm_params", {}).get("metadata", None) _model_info = kwargs.get("litellm_params", {}).get("model_info", {}) + if isinstance(_model_info, dict): deployment_id = _model_info.get("id", None) self._set_cooldown_deployments( - deployment_id + exception_status=exception_status, deployment=deployment_id ) # setting deployment_id in cooldown deployments if custom_llm_provider: model_name = f"{custom_llm_provider}/{model_name}" @@ -1814,9 +1809,15 @@ class Router: key=rpm_key, value=request_count, local_only=True ) # don't change existing ttl - def _set_cooldown_deployments(self, deployment: Optional[str] = None): + def _set_cooldown_deployments( + self, exception_status: Union[str, int], deployment: Optional[str] = None + ): """ Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute + + or + + the exception is not one that should be immediately retried (e.g. 
401) """ if deployment is None: return @@ -1833,7 +1834,20 @@ class Router: f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}" ) cooldown_time = self.cooldown_time or 1 - if updated_fails > self.allowed_fails: + + if isinstance(exception_status, str): + try: + exception_status = int(exception_status) + except Exception as e: + verbose_router_logger.debug( + "Unable to cast exception status to int {}. Defaulting to status=500.".format( + exception_status + ) + ) + exception_status = 500 + _should_retry = litellm._should_retry(status_code=exception_status) + + if updated_fails > self.allowed_fails or _should_retry == False: # get the current cooldown list for that minute cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls cached_value = self.cache.get_cache(key=cooldown_key) diff --git a/litellm/tests/conftest.py b/litellm/tests/conftest.py index 4cd277b31..8c2ce781f 100644 --- a/litellm/tests/conftest.py +++ b/litellm/tests/conftest.py @@ -19,6 +19,7 @@ def setup_and_teardown(): 0, os.path.abspath("../..") ) # Adds the project directory to the system path import litellm + from litellm import Router importlib.reload(litellm) import asyncio diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 8c6b9fa01..2d277d749 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -1154,6 +1154,7 @@ def test_consistent_model_id(): assert id1 == id2 +@pytest.mark.skip(reason="local test") def test_reading_keys_os_environ(): import openai @@ -1253,6 +1254,7 @@ def test_reading_keys_os_environ(): # test_reading_keys_os_environ() +@pytest.mark.skip(reason="local test") def test_reading_openai_keys_os_environ(): import openai diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py index 364319929..a4110518b 100644 --- a/litellm/tests/test_router_fallbacks.py +++ b/litellm/tests/test_router_fallbacks.py @@ -22,10 +22,10 @@ class MyCustomHandler(CustomLogger): def log_pre_api_call(self, model, messages, kwargs): print(f"Pre-API Call") print( - f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}" + f"previous_models: {kwargs['litellm_params']['metadata'].get('previous_models', None)}" ) - self.previous_models += len( - kwargs["litellm_params"]["metadata"]["previous_models"] + self.previous_models = len( + kwargs["litellm_params"]["metadata"].get("previous_models", []) ) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": }]} print(f"self.previous_models: {self.previous_models}") @@ -140,7 +140,7 @@ def test_sync_fallbacks(): @pytest.mark.asyncio async def test_async_fallbacks(): - litellm.set_verbose = False + litellm.set_verbose = True model_list = [ { # list of model deployments "model_name": "azure/gpt-3.5-turbo", # openai model name diff --git a/litellm/tests/test_router_retries.py b/litellm/tests/test_router_retries.py new file mode 100644 index 000000000..3ed08dfd9 --- /dev/null +++ b/litellm/tests/test_router_retries.py @@ -0,0 +1,121 @@ +#### What this tests #### +# This tests calling router with fallback models + +import sys, os, time +import traceback, asyncio +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +import litellm +from litellm import Router +from litellm.integrations.custom_logger import CustomLogger + + +class 
MyCustomHandler(CustomLogger): + success: bool = False + failure: bool = False + previous_models: int = 0 + + def log_pre_api_call(self, model, messages, kwargs): + print(f"Pre-API Call") + print( + f"previous_models: {kwargs['litellm_params']['metadata'].get('previous_models', None)}" + ) + self.previous_models = len( + kwargs["litellm_params"]["metadata"].get("previous_models", []) + ) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": }]} + print(f"self.previous_models: {self.previous_models}") + + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + print( + f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}" + ) + + def log_stream_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Stream") + + def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Stream") + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Success") + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Success") + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Failure") + + +""" +Test sync + async + +- Authorization Errors +- Random API Error +""" + + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.parametrize("error_type", ["Authorization Error", "API Error"]) +@pytest.mark.asyncio +async def test_router_retries_errors(sync_mode, error_type): + """ + - Auth Error -> 0 retries + - API Error -> 2 retries + """ + + _api_key = ( + "bad-key" if error_type == "Authorization Error" else os.getenv("AZURE_API_KEY") + ) + print(f"_api_key: {_api_key}") + model_list = [ + { + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": _api_key, + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + ] + + router = Router(model_list=model_list, allowed_fails=3) + + customHandler = MyCustomHandler() + litellm.callbacks = [customHandler] + user_message = "Hello, how are you?" 
+ messages = [{"content": user_message, "role": "user"}] + + kwargs = { + "model": "azure/gpt-3.5-turbo", + "messages": messages, + "mock_response": ( + None + if error_type == "Authorization Error" + else Exception("Invalid Request") + ), + } + + try: + if sync_mode: + response = router.completion(**kwargs) + else: + response = await router.acompletion(**kwargs) + except Exception as e: + pass + + await asyncio.sleep( + 0.05 + ) # allow a delay as success_callbacks are on a separate thread + print(f"customHandler.previous_models: {customHandler.previous_models}") + + if error_type == "Authorization Error": + assert customHandler.previous_models == 0 # 0 retries + else: + assert customHandler.previous_models == 2 # 2 retries From bc5c9d7da9ef58c8af4c21a02b92a525d8386410 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 30 Apr 2024 18:48:39 -0700 Subject: [PATCH 5/6] fix(test_router_fallbacks.py): fix tests --- litellm/router.py | 25 --------------- litellm/tests/test_router_fallbacks.py | 44 ++++++++++++++------------ 2 files changed, 24 insertions(+), 45 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index bcf2d2cb6..14efc2a56 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1512,31 +1512,6 @@ class Router: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=e) remaining_retries = num_retries - current_attempt - # if "No models available" in str(e): - # timeout = litellm._calculate_retry_after( - # remaining_retries=remaining_retries, - # max_retries=num_retries, - # min_timeout=self.retry_after, - # ) - # await asyncio.sleep(timeout) - # elif ( - # hasattr(e, "status_code") - # and hasattr(e, "response") - # and litellm._should_retry(status_code=e.status_code) - # ): - # if hasattr(e.response, "headers"): - # timeout = litellm._calculate_retry_after( - # remaining_retries=remaining_retries, - # max_retries=num_retries, - # response_headers=e.response.headers, - # min_timeout=self.retry_after, - # ) - # else: - # timeout = litellm._calculate_retry_after( - # remaining_retries=remaining_retries, - # max_retries=num_retries, - # min_timeout=self.retry_after, - # ) _timeout = self._router_should_retry( e=original_exception, remaining_retries=remaining_retries, diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py index a4110518b..7027050e1 100644 --- a/litellm/tests/test_router_fallbacks.py +++ b/litellm/tests/test_router_fallbacks.py @@ -127,7 +127,7 @@ def test_sync_fallbacks(): response = router.completion(**kwargs) print(f"response: {response}") time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback + assert customHandler.previous_models == 4 print("Passed ! Test router_fallbacks: test_sync_fallbacks()") router.reset() @@ -209,12 +209,13 @@ async def test_async_fallbacks(): user_message = "Hello, how are you?" 
messages = [{"content": user_message, "role": "user"}] try: + kwargs["model"] = "azure/gpt-3.5-turbo" response = await router.acompletion(**kwargs) print(f"customHandler.previous_models: {customHandler.previous_models}") await asyncio.sleep( 0.05 ) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback + assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback router.reset() except litellm.Timeout as e: pass @@ -258,7 +259,6 @@ def test_sync_fallbacks_embeddings(): model_list=model_list, fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}], set_verbose=False, - num_retries=0, ) customHandler = MyCustomHandler() litellm.callbacks = [customHandler] @@ -269,7 +269,7 @@ def test_sync_fallbacks_embeddings(): response = router.embedding(**kwargs) print(f"customHandler.previous_models: {customHandler.previous_models}") time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback + assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback router.reset() except litellm.Timeout as e: pass @@ -323,7 +323,7 @@ async def test_async_fallbacks_embeddings(): await asyncio.sleep( 0.05 ) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback + assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback router.reset() except litellm.Timeout as e: pass @@ -394,7 +394,7 @@ def test_dynamic_fallbacks_sync(): }, ] - router = Router(model_list=model_list, set_verbose=True, num_retries=0) + router = Router(model_list=model_list, set_verbose=True) kwargs = {} kwargs["model"] = "azure/gpt-3.5-turbo" kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}] @@ -402,7 +402,7 @@ def test_dynamic_fallbacks_sync(): response = router.completion(**kwargs) print(f"response: {response}") time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback + assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback router.reset() except Exception as e: pytest.fail(f"An exception occurred - {e}") @@ -488,7 +488,7 @@ async def test_dynamic_fallbacks_async(): await asyncio.sleep( 0.05 ) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback + assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback router.reset() except Exception as e: pytest.fail(f"An exception occurred - {e}") @@ -573,7 +573,7 @@ async def test_async_fallbacks_streaming(): await asyncio.sleep( 0.05 ) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback + assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback router.reset() except litellm.Timeout as e: pass @@ -766,10 +766,10 @@ def test_usage_based_routing_fallbacks(): load_dotenv() # Constants for TPM and RPM allocation - AZURE_FAST_TPM = 3 - AZURE_BASIC_TPM = 4 - OPENAI_TPM = 400 - ANTHROPIC_TPM = 100000 + AZURE_FAST_RPM = 3 + AZURE_BASIC_RPM = 4 + OPENAI_RPM = 10 + ANTHROPIC_RPM = 100000 def get_azure_params(deployment_name: str): params = { @@ -798,22 +798,26 @@ def test_usage_based_routing_fallbacks(): { "model_name": "azure/gpt-4-fast", "litellm_params": 
get_azure_params("chatgpt-v-2"), - "tpm": AZURE_FAST_TPM, + "model_info": {"id": 1}, + "rpm": AZURE_FAST_RPM, }, { "model_name": "azure/gpt-4-basic", "litellm_params": get_azure_params("chatgpt-v-2"), - "tpm": AZURE_BASIC_TPM, + "model_info": {"id": 2}, + "rpm": AZURE_BASIC_RPM, }, { "model_name": "openai-gpt-4", "litellm_params": get_openai_params("gpt-3.5-turbo"), - "tpm": OPENAI_TPM, + "model_info": {"id": 3}, + "rpm": OPENAI_RPM, }, { "model_name": "anthropic-claude-instant-1.2", "litellm_params": get_anthropic_params("claude-instant-1.2"), - "tpm": ANTHROPIC_TPM, + "model_info": {"id": 4}, + "rpm": ANTHROPIC_RPM, }, ] # litellm.set_verbose=True @@ -844,10 +848,10 @@ def test_usage_based_routing_fallbacks(): mock_response="very nice to meet you", ) print("response: ", response) - print("response._hidden_params: ", response._hidden_params) + print(f"response._hidden_params: {response._hidden_params}") # in this test, we expect azure/gpt-4 fast to fail, then azure-gpt-4 basic to fail and then openai-gpt-4 to pass # the token count of this message is > AZURE_FAST_TPM, > AZURE_BASIC_TPM - assert response._hidden_params["custom_llm_provider"] == "openai" + assert response._hidden_params["model_id"] == "1" # now make 100 mock requests to OpenAI - expect it to fallback to anthropic-claude-instant-1.2 for i in range(20): @@ -861,7 +865,7 @@ def test_usage_based_routing_fallbacks(): print("response._hidden_params: ", response._hidden_params) if i == 19: # by the 19th call we should have hit TPM LIMIT for OpenAI, it should fallback to anthropic-claude-instant-1.2 - assert response._hidden_params["custom_llm_provider"] == "anthropic" + assert response._hidden_params["model_id"] == "4" except Exception as e: pytest.fail(f"An exception occurred {e}") From 4761345311f98ada5bdf6d7bcf0e89b5489429b5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 30 Apr 2024 19:30:18 -0700 Subject: [PATCH 6/6] fix(main.py): fix mock completion response --- litellm/main.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index 569418eca..cdea40d11 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -360,7 +360,7 @@ def mock_completion( model: str, messages: List, stream: Optional[bool] = False, - mock_response: str = "This is a mock request", + mock_response: Union[str, Exception] = "This is a mock request", logging=None, **kwargs, ): @@ -388,10 +388,11 @@ def mock_completion( """ try: ## LOGGING - logging.pre_call( - input=messages, - api_key="mock-key", - ) + if logging is not None: + logging.pre_call( + input=messages, + api_key="mock-key", + ) if isinstance(mock_response, Exception): raise litellm.APIError( status_code=500, # type: ignore
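
The thread running through these patches is that both retry paths now share one back-off calculation (_router_should_retry returns the delay and the caller sleeps on it), and non-retryable errors such as 401s stop retries and put the deployment into cooldown. Below is a minimal, self-contained sketch of that pattern; the helper names, the exponential back-off, and the retryable status-code set are illustrative assumptions, not litellm's actual implementation.

import asyncio
import random
import time
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")

# Assumed retryable status codes, for illustration only; in the patches above
# the real gate is litellm._should_retry().
RETRYABLE_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504}


def calculate_retry_after(remaining_retries: int, max_retries: int, min_timeout: float = 0.5) -> float:
    # Exponential back-off with jitter; stands in for litellm._calculate_retry_after().
    attempt = max_retries - remaining_retries
    delay = min_timeout * (2 ** attempt)
    return delay + random.uniform(0, 0.1 * delay)


def should_retry(exc: Exception) -> bool:
    # Fail fast on non-retryable errors (e.g. 401); retry the rest.
    status = getattr(exc, "status_code", 500)
    return status in RETRYABLE_STATUS_CODES


def retry_sync(fn: Callable[[], T], num_retries: int = 2, min_timeout: float = 0.5) -> T:
    # One initial attempt plus num_retries retries.
    for attempt in range(num_retries + 1):
        try:
            return fn()
        except Exception as e:
            remaining = num_retries - attempt
            if remaining == 0 or not should_retry(e):
                raise
            # Same delay helper as the async path; only the sleep primitive differs.
            time.sleep(calculate_retry_after(remaining, num_retries, min_timeout))
    raise AssertionError("unreachable")


async def retry_async(fn: Callable[[], Awaitable[T]], num_retries: int = 2, min_timeout: float = 0.5) -> T:
    for attempt in range(num_retries + 1):
        try:
            return await fn()
        except Exception as e:
            remaining = num_retries - attempt
            if remaining == 0 or not should_retry(e):
                raise
            await asyncio.sleep(calculate_retry_after(remaining, num_retries, min_timeout))
    raise AssertionError("unreachable")

Usage would look like retry_sync(lambda: client.complete(...)) or await retry_async(lambda: client.acomplete(...)), where client is a hypothetical stand-in for the router's deployment call. Keeping the delay computation pure (it returns a number instead of sleeping) is what lets the sync and async wrappers stay identical apart from the sleep call, which is the refactor PATCH 1/6 makes to _router_should_retry.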