From d0a30529374ac45923f0038f3556919cb50cb906 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 14 Oct 2024 21:27:54 +0530 Subject: [PATCH] (refactor router.py ) - PR 3 - Ensure all functions under 100 lines (#6181) * add flake 8 check * split up litellm _acompletion * fix get model client * refactor use commong func to add metadata to kwargs * use common func to get timeout * re-use helper to _get_async_model_client * use _handle_mock_testing_rate_limit_error * fix docstring for _handle_mock_testing_rate_limit_error * fix function_with_retries * use helper for mock testing fallbacks * router - use 1 func for simple_shuffle * add doc string for simple_shuffle * use 1 function for filtering cooldown deployments * fix use common helper to _get_fallback_model_group_from_fallbacks --- .pre-commit-config.yaml | 5 + litellm/router.py | 916 ++++++------------ litellm/router_strategy/simple_shuffle.py | 96 ++ .../router_utils/fallback_event_handlers.py | 2 +- tests/local_testing/test_azure_openai.py | 1 + 5 files changed, 422 insertions(+), 598 deletions(-) create mode 100644 litellm/router_strategy/simple_shuffle.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4f93569b2..b8567fce7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,6 +25,11 @@ repos: exclude: ^litellm/tests/|^litellm/proxy/tests/ additional_dependencies: [flake8-print] files: litellm/.*\.py + # - id: flake8 + # name: flake8 (router.py function length) + # files: ^litellm/router\.py$ + # args: [--max-function-length=40] + # # additional_dependencies: [flake8-functions] - repo: https://github.com/python-poetry/poetry rev: 1.8.0 hooks: diff --git a/litellm/router.py b/litellm/router.py index 845017465..c31536bd6 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -46,6 +46,7 @@ from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2 +from litellm.router_strategy.simple_shuffle import simple_shuffle from litellm.router_strategy.tag_based_routing import get_deployments_for_tag from litellm.router_utils.batch_utils import ( _get_router_metadata_variable_name, @@ -623,23 +624,10 @@ class Router: messages=messages, specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "api_base": deployment.get("litellm_params", {}).get("api_base"), - "model_info": deployment.get("model_info", {}), - } - ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() - kwargs["model_info"] = deployment.get("model_info", {}) model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) potential_model_client = self._get_client( deployment=deployment, kwargs=kwargs ) @@ -757,7 +745,6 @@ class Router: verbose_router_logger.debug( f"Inside _acompletion()- model: {model}; kwargs: {kwargs}" ) - deployment = await self.async_get_available_deployment( model=model, messages=messages, @@ -767,53 +754,19 @@ class Router: # debug how often this deployment picked self._track_deployment_metrics(deployment=deployment) + 
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - "api_base": deployment.get("litellm_params", {}).get("api_base"), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) data = deployment["litellm_params"].copy() - model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs and v is not None - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 - timeout = ( - data.get( - "timeout", None - ) # timeout set on litellm_params for this deployment - or data.get( - "request_timeout", None - ) # timeout set on litellm_params for this deployment - or self.timeout # timeout set on router - or kwargs.get( - "timeout", None - ) # this uses default_litellm_params when nothing is set + timeout: Optional[Union[float, int]] = self._get_timeout( + kwargs=kwargs, data=data ) _response = litellm.acompletion( @@ -882,6 +835,70 @@ class Router: self.fail_calls[model_name] += 1 raise e + def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None: + """ + Adds selected deployment, model_info and api_base to kwargs["metadata"] + + This is used in litellm logging callbacks + """ + kwargs.setdefault("metadata", {}).update( + { + "deployment": deployment["litellm_params"]["model"], + "model_info": deployment.get("model_info", {}), + "api_base": deployment.get("litellm_params", {}).get("api_base"), + } + ) + kwargs["model_info"] = deployment.get("model_info", {}) + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs and v is not None + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + def _get_async_openai_model_client(self, deployment: dict, kwargs: dict): + """ + Helper to get AsyncOpenAI or AsyncAzureOpenAI client that was created for the deployment + + The same OpenAI client is re-used to optimize latency / performance in production + + If dynamic api key is provided: + Do not re-use the client. Pass model_client=None. 
The OpenAI/ AzureOpenAI client will be recreated in the handler for the llm provider + """ + potential_model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="async" + ) + + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + + return model_client + + def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]: + timeout = ( + data.get( + "timeout", None + ) # timeout set on litellm_params for this deployment + or data.get( + "request_timeout", None + ) # timeout set on litellm_params for this deployment + or self.timeout # timeout set on router + or kwargs.get( + "timeout", None + ) # this uses default_litellm_params when nothing is set + ) + + return timeout + async def abatch_completion( self, models: List[str], @@ -1218,36 +1235,13 @@ class Router: messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) data = deployment["litellm_params"].copy() - model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 @@ -1309,36 +1303,15 @@ class Router: messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 response = 
litellm.aimage_generation( @@ -1442,36 +1415,13 @@ class Router: messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) - data = deployment["litellm_params"].copy() - model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 response = litellm.atranscription( @@ -1640,46 +1590,18 @@ class Router: input=input, specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs and v is not None - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 - timeout = ( - data.get( - "timeout", None - ) # timeout set on litellm_params for this deployment - or self.timeout # timeout set on router - or kwargs.get( - "timeout", None - ) # this uses default_litellm_params when nothing is set + timeout: Optional[Union[float, int]] = self._get_timeout( + kwargs=kwargs, + data=data, ) response = await litellm.amoderation( @@ -1739,46 +1661,19 @@ class Router: model=model, specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs and v is not None - ): # 
prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 - timeout = ( - data.get( - "timeout", None - ) # timeout set on litellm_params for this deployment - or self.timeout # timeout set on router - or kwargs.get( - "timeout", None - ) # this uses default_litellm_params when nothing is set + timeout: Optional[Union[float, int]] = self._get_timeout( + kwargs=kwargs, + data=data, ) response = await litellm.arerank( @@ -1923,37 +1818,15 @@ class Router: messages=[{"role": "user", "content": prompt}], specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - "api_base": deployment.get("litellm_params", {}).get("api_base"), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 response = litellm.atext_completion( @@ -2042,37 +1915,15 @@ class Router: messages=[{"role": "user", "content": "default text"}], specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - "api_base": deployment.get("litellm_params", {}).get("api_base"), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and 
dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 response = litellm.aadapter_completion( @@ -2151,22 +2002,9 @@ class Router: input=input, specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) potential_model_client = self._get_client( deployment=deployment, kwargs=kwargs, client_type="sync" @@ -2247,37 +2085,13 @@ class Router: input=input, specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - "api_base": deployment.get("litellm_params", {}).get("api_base"), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 response = litellm.aembedding( @@ -2367,37 +2181,15 @@ class Router: messages=[{"role": "user", "content": "files-api-fake-text"}], specific_deployment=kwargs.pop("specific_deployment", None), ) - kwargs.setdefault("metadata", {}).update( - { - "deployment": deployment["litellm_params"]["model"], - "model_info": deployment.get("model_info", {}), - "api_base": deployment.get("litellm_params", {}).get("api_base"), - } - ) - kwargs["model_info"] = deployment.get("model_info", {}) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = 
None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 ## REPLACE MODEL IN FILE WITH SELECTED DEPLOYMENT ## @@ -2515,19 +2307,10 @@ class Router: elif k == metadata_variable_name: kwargs[k].update(v) - potential_model_client = self._get_client( - deployment=deployment, kwargs=kwargs, client_type="async" + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, ) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client self.total_calls[model_name] += 1 ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ## @@ -2890,48 +2673,22 @@ class Router: Try calling the function_with_retries If it fails after num_retries, fall back to another model group """ - mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None) - mock_testing_context_fallbacks = kwargs.pop( - "mock_testing_context_fallbacks", None - ) - mock_testing_content_policy_fallbacks = kwargs.pop( - "mock_testing_content_policy_fallbacks", None - ) - model_group = kwargs.get("model") - fallbacks = kwargs.get("fallbacks", self.fallbacks) - context_window_fallbacks = kwargs.get( + model_group: Optional[str] = kwargs.get("model") + fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks) + context_window_fallbacks: Optional[List] = kwargs.get( "context_window_fallbacks", self.context_window_fallbacks ) - content_policy_fallbacks = kwargs.get( + content_policy_fallbacks: Optional[List] = kwargs.get( "content_policy_fallbacks", self.content_policy_fallbacks ) try: - if mock_testing_fallbacks is not None and mock_testing_fallbacks is True: - raise litellm.InternalServerError( - model=model_group, - llm_provider="", - message=f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}", - ) - elif ( - mock_testing_context_fallbacks is not None - and mock_testing_context_fallbacks is True - ): - raise litellm.ContextWindowExceededError( - model=model_group, - llm_provider="", - message=f"This is a mock exception for model={model_group}, to trigger a fallback. \ - Context_Window_Fallbacks={context_window_fallbacks}", - ) - elif ( - mock_testing_content_policy_fallbacks is not None - and mock_testing_content_policy_fallbacks is True - ): - raise litellm.ContentPolicyViolationError( - model=model_group, - llm_provider="", - message=f"This is a mock exception for model={model_group}, to trigger a fallback. 
\ - Context_Policy_Fallbacks={content_policy_fallbacks}", - ) + self._handle_mock_testing_fallbacks( + kwargs=kwargs, + model_group=model_group, + fallbacks=fallbacks, + context_window_fallbacks=context_window_fallbacks, + content_policy_fallbacks=content_policy_fallbacks, + ) response = await self.async_function_with_retries(*args, **kwargs) verbose_router_logger.debug(f"Async Response: {response}") @@ -2950,14 +2707,12 @@ class Router: verbose_router_logger.debug("Trying to fallback b/w models") if isinstance(e, litellm.ContextWindowExceededError): if context_window_fallbacks is not None: - fallback_model_group = None - for ( - item - ) in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] - if list(item.keys())[0] == model_group: - fallback_model_group = item[model_group] - break - + fallback_model_group: Optional[List[str]] = ( + self._get_fallback_model_group_from_fallbacks( + fallbacks=context_window_fallbacks, + model_group=model_group, + ) + ) if fallback_model_group is None: raise original_exception @@ -2985,14 +2740,12 @@ class Router: e.message += "\n{}".format(error_message) elif isinstance(e, litellm.ContentPolicyViolationError): if content_policy_fallbacks is not None: - fallback_model_group = None - for ( - item - ) in content_policy_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] - if list(item.keys())[0] == model_group: - fallback_model_group = item[model_group] - break - + fallback_model_group: Optional[List[str]] = ( + self._get_fallback_model_group_from_fallbacks( + fallbacks=content_policy_fallbacks, + model_group=model_group, + ) + ) if fallback_model_group is None: raise original_exception @@ -3081,14 +2834,62 @@ class Router: raise original_exception + def _handle_mock_testing_fallbacks( + self, + kwargs: dict, + model_group: Optional[str] = None, + fallbacks: Optional[List] = None, + context_window_fallbacks: Optional[List] = None, + content_policy_fallbacks: Optional[List] = None, + ): + """ + Helper function to raise a litellm Error for mock testing purposes. + + Raises: + litellm.InternalServerError: when `mock_testing_fallbacks=True` passed in request params + litellm.ContextWindowExceededError: when `mock_testing_context_fallbacks=True` passed in request params + litellm.ContentPolicyViolationError: when `mock_testing_content_policy_fallbacks=True` passed in request params + """ + mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None) + mock_testing_context_fallbacks = kwargs.pop( + "mock_testing_context_fallbacks", None + ) + mock_testing_content_policy_fallbacks = kwargs.pop( + "mock_testing_content_policy_fallbacks", None + ) + + if mock_testing_fallbacks is not None and mock_testing_fallbacks is True: + raise litellm.InternalServerError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}", + ) + elif ( + mock_testing_context_fallbacks is not None + and mock_testing_context_fallbacks is True + ): + raise litellm.ContextWindowExceededError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a fallback. \ + Context_Window_Fallbacks={context_window_fallbacks}", + ) + elif ( + mock_testing_content_policy_fallbacks is not None + and mock_testing_content_policy_fallbacks is True + ): + raise litellm.ContentPolicyViolationError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a fallback. 
\ + Context_Policy_Fallbacks={content_policy_fallbacks}", + ) + async def async_function_with_retries(self, *args, **kwargs): verbose_router_logger.debug( f"Inside async function with retries: args - {args}; kwargs - {kwargs}" ) original_function = kwargs.pop("original_function") - mock_testing_rate_limit_error = kwargs.pop( - "mock_testing_rate_limit_error", None - ) fallbacks = kwargs.pop("fallbacks", self.fallbacks) context_window_fallbacks = kwargs.pop( "context_window_fallbacks", self.context_window_fallbacks @@ -3096,7 +2897,7 @@ class Router: content_policy_fallbacks = kwargs.pop( "content_policy_fallbacks", self.content_policy_fallbacks ) - model_group = kwargs.get("model") + model_group: Optional[str] = kwargs.get("model") num_retries = kwargs.pop("num_retries") ## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking @@ -3110,18 +2911,9 @@ class Router: f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}" ) try: - if ( - mock_testing_rate_limit_error is not None - and mock_testing_rate_limit_error is True - ): - verbose_router_logger.info( - "litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError." - ) - raise litellm.RateLimitError( - model=model_group, - llm_provider="", - message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.", - ) + self._handle_mock_testing_rate_limit_error( + model_group=model_group, kwargs=kwargs + ) # if the function call is successful, no exception will be raised and we'll break out of the loop response = await self.make_call(original_function, *args, **kwargs) @@ -3222,6 +3014,31 @@ class Router: return response + def _handle_mock_testing_rate_limit_error( + self, kwargs: dict, model_group: Optional[str] = None + ): + """ + Helper function to raise a mock litellm.RateLimitError error for testing purposes. + + Raises: + litellm.RateLimitError error when `mock_testing_rate_limit_error=True` passed in request params + """ + mock_testing_rate_limit_error: Optional[bool] = kwargs.pop( + "mock_testing_rate_limit_error", None + ) + if ( + mock_testing_rate_limit_error is not None + and mock_testing_rate_limit_error is True + ): + verbose_router_logger.info( + f"litellm.router.py::_mock_rate_limit_error() - Raising mock RateLimitError for model={model_group}" + ) + raise litellm.RateLimitError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.", + ) + def should_retry_this_error( self, error: Exception, @@ -3292,9 +3109,7 @@ class Router: Try calling the function_with_retries If it fails after num_retries, fall back to another model group """ - mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None) - - model_group = kwargs.get("model") + model_group: Optional[str] = kwargs.get("model") fallbacks = kwargs.get("fallbacks", self.fallbacks) context_window_fallbacks = kwargs.get( "context_window_fallbacks", self.context_window_fallbacks @@ -3304,11 +3119,13 @@ class Router: ) try: - if mock_testing_fallbacks is not None and mock_testing_fallbacks is True: - raise Exception( - f"This is a mock exception for model={model_group}, to trigger a fallback. 
Fallbacks={fallbacks}" - ) - + self._handle_mock_testing_fallbacks( + kwargs=kwargs, + model_group=model_group, + fallbacks=fallbacks, + context_window_fallbacks=context_window_fallbacks, + content_policy_fallbacks=content_policy_fallbacks, + ) response = self.function_with_retries(*args, **kwargs) return response except Exception as e: @@ -3329,12 +3146,12 @@ class Router: ): fallback_model_group = None - for ( - item - ) in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] - if list(item.keys())[0] == model_group: - fallback_model_group = item[model_group] - break + fallback_model_group: Optional[List[str]] = ( + self._get_fallback_model_group_from_fallbacks( + fallbacks=context_window_fallbacks, + model_group=model_group, + ) + ) if fallback_model_group is None: raise original_exception @@ -3351,14 +3168,12 @@ class Router: isinstance(e, litellm.ContentPolicyViolationError) and content_policy_fallbacks is not None ): - fallback_model_group = None - - for ( - item - ) in content_policy_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] - if list(item.keys())[0] == model_group: - fallback_model_group = item[model_group] - break + fallback_model_group: Optional[List[str]] = ( + self._get_fallback_model_group_from_fallbacks( + fallbacks=content_policy_fallbacks, + model_group=model_group, + ) + ) if fallback_model_group is None: raise original_exception @@ -3406,6 +3221,31 @@ class Router: raise e raise original_exception + def _get_fallback_model_group_from_fallbacks( + self, + fallbacks: List[Dict[str, List[str]]], + model_group: Optional[str] = None, + ) -> Optional[List[str]]: + """ + Returns the list of fallback models to use for a given model group + + If no fallback model group is found, returns None + + Example: + fallbacks = [{"gpt-3.5-turbo": ["gpt-4"]}, {"gpt-4o": ["gpt-3.5-turbo"]}] + model_group = "gpt-3.5-turbo" + returns: ["gpt-4"] + """ + if model_group is None: + return None + + fallback_model_group: Optional[List[str]] = None + for item in fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] + if list(item.keys())[0] == model_group: + fallback_model_group = item[model_group] + break + return fallback_model_group + def _time_to_sleep_before_retry( self, e: Exception, @@ -3458,9 +3298,6 @@ class Router: f"Inside function with retries: args - {args}; kwargs - {kwargs}" ) original_function = kwargs.pop("original_function") - mock_testing_rate_limit_error = kwargs.pop( - "mock_testing_rate_limit_error", None - ) num_retries = kwargs.pop("num_retries") fallbacks = kwargs.pop("fallbacks", self.fallbacks) context_window_fallbacks = kwargs.pop( @@ -3473,18 +3310,9 @@ class Router: try: # if the function call is successful, no exception will be raised and we'll break out of the loop - if ( - mock_testing_rate_limit_error is not None - and mock_testing_rate_limit_error is True - ): - verbose_router_logger.info( - "litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError." 
- ) - raise litellm.RateLimitError( - model=model_group, - llm_provider="", - message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.", - ) + self._handle_mock_testing_rate_limit_error( + kwargs=kwargs, model_group=model_group + ) response = original_function(*args, **kwargs) return response except Exception as e: @@ -5272,23 +5100,16 @@ class Router: if isinstance(healthy_deployments, dict): return healthy_deployments - # filter out the deployments currently cooling down - deployments_to_remove = [] - # cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"] cooldown_deployments = await _async_get_cooldown_deployments( litellm_router_instance=self ) verbose_router_logger.debug( f"async cooldown deployments: {cooldown_deployments}" ) - # Find deployments in model_list whose model_id is cooling down - for deployment in healthy_deployments: - deployment_id = deployment["model_info"]["id"] - if deployment_id in cooldown_deployments: - deployments_to_remove.append(deployment) - # remove unhealthy deployments from healthy deployments - for deployment in deployments_to_remove: - healthy_deployments.remove(deployment) + healthy_deployments = self._filter_cooldown_deployments( + healthy_deployments=healthy_deployments, + cooldown_deployments=cooldown_deployments, + ) # filter pre-call checks _allowed_model_region = ( @@ -5353,78 +5174,11 @@ class Router: ) ) elif self.routing_strategy == "simple-shuffle": - # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm - - ############## Check if 'weight' param set for a weighted pick ################# - weight = ( - healthy_deployments[0].get("litellm_params").get("weight", None) + return simple_shuffle( + llm_router_instance=self, + healthy_deployments=healthy_deployments, + model=model, ) - if weight is not None: - # use weight-random pick if rpms provided - weights = [ - m["litellm_params"].get("weight", 0) - for m in healthy_deployments - ] - verbose_router_logger.debug(f"\nweight {weights}") - total_weight = sum(weights) - weights = [weight / total_weight for weight in weights] - verbose_router_logger.debug(f"\n weights {weights}") - # Perform weighted random pick - selected_index = random.choices( - range(len(weights)), weights=weights - )[0] - verbose_router_logger.debug(f"\n selected index, {selected_index}") - deployment = healthy_deployments[selected_index] - verbose_router_logger.info( - f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}" - ) - return deployment or deployment[0] - ############## Check if we can do a RPM/TPM based weighted pick ################# - rpm = healthy_deployments[0].get("litellm_params").get("rpm", None) - if rpm is not None: - # use weight-random pick if rpms provided - rpms = [ - m["litellm_params"].get("rpm", 0) for m in healthy_deployments - ] - verbose_router_logger.debug(f"\nrpms {rpms}") - total_rpm = sum(rpms) - weights = [rpm / total_rpm for rpm in rpms] - verbose_router_logger.debug(f"\n weights {weights}") - # Perform weighted random pick - selected_index = random.choices(range(len(rpms)), weights=weights)[ - 0 - ] - verbose_router_logger.debug(f"\n selected index, {selected_index}") - deployment = healthy_deployments[selected_index] - verbose_router_logger.info( - f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) 
or deployment[0]} for model: {model}" - ) - return deployment or deployment[0] - ############## Check if we can do a RPM/TPM based weighted pick ################# - tpm = healthy_deployments[0].get("litellm_params").get("tpm", None) - if tpm is not None: - # use weight-random pick if rpms provided - tpms = [ - m["litellm_params"].get("tpm", 0) for m in healthy_deployments - ] - verbose_router_logger.debug(f"\ntpms {tpms}") - total_tpm = sum(tpms) - weights = [tpm / total_tpm for tpm in tpms] - verbose_router_logger.debug(f"\n weights {weights}") - # Perform weighted random pick - selected_index = random.choices(range(len(tpms)), weights=weights)[ - 0 - ] - verbose_router_logger.debug(f"\n selected index, {selected_index}") - deployment = healthy_deployments[selected_index] - verbose_router_logger.info( - f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}" - ) - return deployment or deployment[0] - - ############## No RPM/TPM passed, we do a random pick ################# - item = random.choice(healthy_deployments) - return item or item[0] else: deployment = None if deployment is None: @@ -5489,19 +5243,11 @@ class Router: if isinstance(healthy_deployments, dict): return healthy_deployments - # filter out the deployments currently cooling down - deployments_to_remove = [] - # cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"] cooldown_deployments = _get_cooldown_deployments(litellm_router_instance=self) - verbose_router_logger.debug(f"cooldown deployments: {cooldown_deployments}") - # Find deployments in model_list whose model_id is cooling down - for deployment in healthy_deployments: - deployment_id = deployment["model_info"]["id"] - if deployment_id in cooldown_deployments: - deployments_to_remove.append(deployment) - # remove unhealthy deployments from healthy deployments - for deployment in deployments_to_remove: - healthy_deployments.remove(deployment) + healthy_deployments = self._filter_cooldown_deployments( + healthy_deployments=healthy_deployments, + cooldown_deployments=cooldown_deployments, + ) # filter pre-call checks if self.enable_pre_call_checks and messages is not None: @@ -5530,62 +5276,11 @@ class Router: elif self.routing_strategy == "simple-shuffle": # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm ############## Check 'weight' param set for weighted pick ################# - weight = healthy_deployments[0].get("litellm_params").get("weight", None) - if weight is not None: - # use weight-random pick if rpms provided - weights = [ - m["litellm_params"].get("weight", 0) for m in healthy_deployments - ] - verbose_router_logger.debug(f"\nweight {weights}") - total_weight = sum(weights) - weights = [weight / total_weight for weight in weights] - verbose_router_logger.debug(f"\n weights {weights}") - # Perform weighted random pick - selected_index = random.choices(range(len(weights)), weights=weights)[0] - verbose_router_logger.debug(f"\n selected index, {selected_index}") - deployment = healthy_deployments[selected_index] - verbose_router_logger.info( - f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}" - ) - return deployment or deployment[0] - ############## Check if we can do a RPM/TPM based weighted pick ################# - rpm = 
healthy_deployments[0].get("litellm_params").get("rpm", None) - if rpm is not None: - # use weight-random pick if rpms provided - rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments] - verbose_router_logger.debug(f"\nrpms {rpms}") - total_rpm = sum(rpms) - weights = [rpm / total_rpm for rpm in rpms] - verbose_router_logger.debug(f"\n weights {weights}") - # Perform weighted random pick - selected_index = random.choices(range(len(rpms)), weights=weights)[0] - verbose_router_logger.debug(f"\n selected index, {selected_index}") - deployment = healthy_deployments[selected_index] - verbose_router_logger.info( - f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}" - ) - return deployment or deployment[0] - ############## Check if we can do a RPM/TPM based weighted pick ################# - tpm = healthy_deployments[0].get("litellm_params").get("tpm", None) - if tpm is not None: - # use weight-random pick if rpms provided - tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments] - verbose_router_logger.debug(f"\ntpms {tpms}") - total_tpm = sum(tpms) - weights = [tpm / total_tpm for tpm in tpms] - verbose_router_logger.debug(f"\n weights {weights}") - # Perform weighted random pick - selected_index = random.choices(range(len(tpms)), weights=weights)[0] - verbose_router_logger.debug(f"\n selected index, {selected_index}") - deployment = healthy_deployments[selected_index] - verbose_router_logger.info( - f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}" - ) - return deployment or deployment[0] - - ############## No RPM/TPM passed, we do a random pick ################# - item = random.choice(healthy_deployments) - return item or item[0] + return simple_shuffle( + llm_router_instance=self, + healthy_deployments=healthy_deployments, + model=model, + ) elif ( self.routing_strategy == "latency-based-routing" and self.lowestlatency_logger is not None @@ -5636,6 +5331,33 @@ class Router: ) return deployment + def _filter_cooldown_deployments( + self, healthy_deployments: List[Dict], cooldown_deployments: List[str] + ) -> List[Dict]: + """ + Filters out the deployments currently cooling down from the list of healthy deployments + + Args: + healthy_deployments: List of healthy deployments + cooldown_deployments: List of model_ids cooling down. 
cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"] + + Returns: + List of healthy deployments + """ + # filter out the deployments currently cooling down + deployments_to_remove = [] + verbose_router_logger.debug(f"cooldown deployments: {cooldown_deployments}") + # Find deployments in model_list whose model_id is cooling down + for deployment in healthy_deployments: + deployment_id = deployment["model_info"]["id"] + if deployment_id in cooldown_deployments: + deployments_to_remove.append(deployment) + + # remove unhealthy deployments from healthy deployments + for deployment in deployments_to_remove: + healthy_deployments.remove(deployment) + return healthy_deployments + def _track_deployment_metrics(self, deployment, response=None): try: litellm_params = deployment["litellm_params"] diff --git a/litellm/router_strategy/simple_shuffle.py b/litellm/router_strategy/simple_shuffle.py new file mode 100644 index 000000000..da24c02f2 --- /dev/null +++ b/litellm/router_strategy/simple_shuffle.py @@ -0,0 +1,96 @@ +""" +Returns a random deployment from the list of healthy deployments. + +If weights are provided, it will return a deployment based on the weights. + +""" + +import random +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from litellm._logging import verbose_router_logger + +if TYPE_CHECKING: + from litellm.router import Router as _Router + + LitellmRouter = _Router +else: + LitellmRouter = Any + + +def simple_shuffle( + llm_router_instance: LitellmRouter, + healthy_deployments: Union[List[Any], Dict[Any, Any]], + model: str, +) -> Dict: + """ + Returns a random deployment from the list of healthy deployments. + + If weights are provided, it will return a deployment based on the weights. + + If users pass `rpm` or `tpm`, we do a random weighted pick - based on `rpm`/`tpm`. 
+ + Args: + llm_router_instance: LitellmRouter instance + healthy_deployments: List of healthy deployments + model: Model name + + Returns: + Dict: A single healthy deployment + """ + + ############## Check if 'weight' param set for a weighted pick ################# + weight = healthy_deployments[0].get("litellm_params").get("weight", None) + if weight is not None: + # use weight-random pick if rpms provided + weights = [m["litellm_params"].get("weight", 0) for m in healthy_deployments] + verbose_router_logger.debug(f"\nweight {weights}") + total_weight = sum(weights) + weights = [weight / total_weight for weight in weights] + verbose_router_logger.debug(f"\n weights {weights}") + # Perform weighted random pick + selected_index = random.choices(range(len(weights)), weights=weights)[0] + verbose_router_logger.debug(f"\n selected index, {selected_index}") + deployment = healthy_deployments[selected_index] + verbose_router_logger.info( + f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}" + ) + return deployment or deployment[0] + ############## Check if we can do a RPM/TPM based weighted pick ################# + rpm = healthy_deployments[0].get("litellm_params").get("rpm", None) + if rpm is not None: + # use weight-random pick if rpms provided + rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments] + verbose_router_logger.debug(f"\nrpms {rpms}") + total_rpm = sum(rpms) + weights = [rpm / total_rpm for rpm in rpms] + verbose_router_logger.debug(f"\n weights {weights}") + # Perform weighted random pick + selected_index = random.choices(range(len(rpms)), weights=weights)[0] + verbose_router_logger.debug(f"\n selected index, {selected_index}") + deployment = healthy_deployments[selected_index] + verbose_router_logger.info( + f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}" + ) + return deployment or deployment[0] + ############## Check if we can do a RPM/TPM based weighted pick ################# + tpm = healthy_deployments[0].get("litellm_params").get("tpm", None) + if tpm is not None: + # use weight-random pick if rpms provided + tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments] + verbose_router_logger.debug(f"\ntpms {tpms}") + total_tpm = sum(tpms) + weights = [tpm / total_tpm for tpm in tpms] + verbose_router_logger.debug(f"\n weights {weights}") + # Perform weighted random pick + selected_index = random.choices(range(len(tpms)), weights=weights)[0] + verbose_router_logger.debug(f"\n selected index, {selected_index}") + deployment = healthy_deployments[selected_index] + verbose_router_logger.info( + f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}" + ) + return deployment or deployment[0] + + ############## No RPM/TPM passed, we do a random pick ################# + item = random.choice(healthy_deployments) + return item or item[0] diff --git a/litellm/router_utils/fallback_event_handlers.py b/litellm/router_utils/fallback_event_handlers.py index 9aab5416f..84495b5a0 100644 --- a/litellm/router_utils/fallback_event_handlers.py +++ b/litellm/router_utils/fallback_event_handlers.py @@ -66,7 +66,7 @@ def run_sync_fallback( **kwargs, ) -> Any: """ - Iterate through the model groups and try calling that deployment. 
+    Iterate through the fallback model groups and try calling each fallback deployment.
     """
     error_from_fallbacks = original_exception
     for mg in fallback_model_group:
diff --git a/tests/local_testing/test_azure_openai.py b/tests/local_testing/test_azure_openai.py
index 9972f2833..e82419c17 100644
--- a/tests/local_testing/test_azure_openai.py
+++ b/tests/local_testing/test_azure_openai.py
@@ -73,6 +73,7 @@ async def test_azure_tenant_id_auth(respx_mock: MockRouter):
         ],
         created=int(datetime.now().timestamp()),
     )
+    litellm.set_verbose = True
     mock_request = respx_mock.post(url__regex=r".*/chat/completions.*").mock(
         return_value=httpx.Response(200, json=obj.model_dump(mode="json"))
     )
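
For readers following the simple_shuffle consolidation above, the sketch below illustrates the selection order the new litellm/router_strategy/simple_shuffle.py module centralizes: prefer an explicit `weight`, then fall back to `rpm`, then `tpm`, and otherwise pick uniformly at random. This is a standalone illustration, not the litellm implementation; the deployment dictionary shape (`litellm_params`, `model_info`) follows the diff, while the function name and the sample deployments are hypothetical.

import random
from typing import Dict, List


def weighted_pick_sketch(healthy_deployments: List[Dict]) -> Dict:
    """Sketch of the simple-shuffle selection order: weight -> rpm -> tpm -> uniform random."""
    for key in ("weight", "rpm", "tpm"):
        # Mirror the extracted helper: a key is only considered when the first
        # deployment in the list sets it on its litellm_params.
        if healthy_deployments[0].get("litellm_params", {}).get(key) is not None:
            values = [d.get("litellm_params", {}).get(key, 0) for d in healthy_deployments]
            total = sum(values)
            if total <= 0:
                continue  # nothing usable to weight on; fall through to the next key
            weights = [v / total for v in values]
            selected_index = random.choices(range(len(values)), weights=weights)[0]
            return healthy_deployments[selected_index]
    # no weight / rpm / tpm provided -> plain uniform random pick
    return random.choice(healthy_deployments)


if __name__ == "__main__":
    deployments = [
        {"litellm_params": {"model": "azure/gpt-4o", "rpm": 900}, "model_info": {"id": "azure-1"}},
        {"litellm_params": {"model": "gpt-4o", "rpm": 100}, "model_info": {"id": "openai-1"}},
    ]
    # with rpm-based weighting, roughly 90% of picks land on "azure-1"
    print(weighted_pick_sketch(deployments)["model_info"]["id"])

Because each branch returns as soon as a usable key is found, a deployment list that sets `weight` never falls through to the `rpm`/`tpm` branches, which matches the early-return structure of the helper extracted in this PR.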