diff --git a/litellm/router.py b/litellm/router.py
index e0c124c0c..deb77162d 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -522,9 +522,9 @@ class Router:
                 messages=messages,
                 specific_deployment=kwargs.pop("specific_deployment", None),
             )
-            if self.set_verbose == True and self.debug_level == "DEBUG":
-                # debug how often this deployment picked
-                self._print_deployment_metrics(deployment=deployment)
+
+            # debug how often this deployment picked
+            self._print_deployment_metrics(deployment=deployment)
 
             kwargs.setdefault("metadata", {}).update(
                 {
@@ -582,9 +582,9 @@ class Router:
             verbose_router_logger.info(
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
             )
-            if self.set_verbose == True and self.debug_level == "DEBUG":
-                # debug how often this deployment picked
-                self._print_deployment_metrics(deployment=deployment, response=response)
+            # debug how often this deployment picked
+            self._print_deployment_metrics(deployment=deployment, response=response)
+
             return response
         except Exception as e:
             verbose_router_logger.info(
@@ -2360,6 +2360,8 @@ class Router:
         except Exception as e:
             return _returned_deployments
 
+        _context_window_error = False
+        _rate_limit_error = False
         for idx, deployment in enumerate(_returned_deployments):
             # see if we have the info for this model
             try:
@@ -2384,19 +2386,48 @@ class Router:
                     and input_tokens > model_info["max_input_tokens"]
                 ):
                     invalid_model_indices.append(idx)
+                    _context_window_error = True
+                    continue
+
+            ## TPM/RPM CHECK ##
+            _litellm_params = deployment.get("litellm_params", {})
+            _model_id = deployment.get("model_info", {}).get("id", "")
+
+            if (
+                isinstance(_litellm_params, dict)
+                and _litellm_params.get("rpm", None) is not None
+            ):
+                if (
+                    isinstance(_litellm_params["rpm"], int)
+                    and _model_id in self.deployment_stats
+                    and _litellm_params["rpm"]
+                    <= self.deployment_stats[_model_id]["num_requests"]
+                ):
+                    invalid_model_indices.append(idx)
+                    _rate_limit_error = True
+                    continue
 
         if len(invalid_model_indices) == len(_returned_deployments):
             """
-            - no healthy deployments available b/c context window checks
+            - no healthy deployments available b/c context window checks or rate limit error
+
+            - First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
             """
-            raise litellm.ContextWindowExceededError(
-                message="Context Window exceeded for given call",
-                model=model,
-                llm_provider="",
-                response=httpx.Response(
-                    status_code=400, request=httpx.Request("GET", "https://example.com")
-                ),
-            )
+
+            if _rate_limit_error == True:  # allow generic fallback logic to take place
+                raise ValueError(
+                    f"No deployments available for selected model, passed model={model}"
+                )
+            elif _context_window_error == True:
+                raise litellm.ContextWindowExceededError(
+                    message="Context Window exceeded for given call",
+                    model=model,
+                    llm_provider="",
+                    response=httpx.Response(
+                        status_code=400,
+                        request=httpx.Request("GET", "https://example.com"),
+                    ),
+                )
         if len(invalid_model_indices) > 0:
             for idx in reversed(invalid_model_indices):
                 _returned_deployments.pop(idx)
@@ -2606,13 +2637,16 @@ class Router:
                     "num_successes": 1,
                     "avg_latency": response_ms,
                 }
-            from pprint import pformat
+            if self.set_verbose == True and self.debug_level == "DEBUG":
+                from pprint import pformat
 
-            # Assuming self.deployment_stats is your dictionary
-            formatted_stats = pformat(self.deployment_stats)
+                # Assuming self.deployment_stats is your dictionary
+                formatted_stats = pformat(self.deployment_stats)
 
-            # Assuming verbose_router_logger is your logger
-            verbose_router_logger.info("self.deployment_stats: \n%s", formatted_stats)
+                # Assuming verbose_router_logger is your logger
+                verbose_router_logger.info(
+                    "self.deployment_stats: \n%s", formatted_stats
+                )
         except Exception as e:
             verbose_router_logger.error(f"Error in _print_deployment_metrics: {str(e)}")
 
diff --git a/tests/test_ratelimit.py b/tests/test_ratelimit.py
index 565f4c3d3..f86fb02d3 100644
--- a/tests/test_ratelimit.py
+++ b/tests/test_ratelimit.py
@@ -4,6 +4,14 @@ import os
 import pytest
 import random
 from typing import Any
+import sys
+from dotenv import load_dotenv
+
+load_dotenv()
+
+sys.path.insert(
+    0, os.path.abspath("../")
+)  # Adds the parent directory to the system path
 
 from pydantic import BaseModel
 from litellm import utils, Router
@@ -35,6 +43,7 @@ def router_factory():
         return Router(
             model_list=model_list,
             routing_strategy=routing_strategy,
+            enable_pre_call_checks=True,
             debug_level="DEBUG",
         )
 
@@ -115,15 +124,23 @@ def test_rate_limit(
         ExpectNoException: Signfies that no other error has happened. A NOP
     """
     # Can send more messages then we're going to; so don't expect a rate limit error
+    args = locals()
+    print(f"args: {args}")
     expected_exception = (
         ExpectNoException if num_try_send <= num_allowed_send else ValueError
     )
 
+    # if (
+    #     num_try_send > num_allowed_send and sync_mode == False
+    # ):  # async calls are made simultaneously - the check for collision would need to happen before the router call
+    #     return
+
     list_of_messages = generate_list_of_messages(max(num_try_send, num_allowed_send))
     rpm, tpm = calculate_limits(list_of_messages[:num_allowed_send])
     list_of_messages = list_of_messages[:num_try_send]
-    router = router_factory(rpm, tpm, routing_strategy)
+    router: Router = router_factory(rpm, tpm, routing_strategy)
+    print(f"router: {router.model_list}")
 
     with pytest.raises(expected_exception) as excinfo:  # asserts correct type raised
         if sync_mode:
             results = sync_call(router, list_of_messages)
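
Reviewer note (not part of the patch): a minimal sketch of how the new pre-call RPM check is meant to be exercised. The model name, the rpm value of 2, and the routing strategy below are illustrative placeholders; what the diff actually introduces is the `rpm` key read from a deployment's `litellm_params`, the `enable_pre_call_checks=True` flag on the Router (as added in router_factory above), and the generic `ValueError` raised once every deployment is over its limit, so the router's ordinary fallback logic can take over instead of a misleading `ContextWindowExceededError`.

from litellm import Router

# Hypothetical single-deployment setup; "rpm" caps requests per deployment.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "rpm": 2,  # once 2 requests are tracked, this deployment is filtered out
        },
    }
]

router = Router(
    model_list=model_list,
    routing_strategy="usage-based-routing",  # assumed; matches the strategies covered by test_ratelimit.py
    enable_pre_call_checks=True,  # enables the context-window + TPM/RPM filtering added in this diff
)

# The third call in quick succession is expected to fail fast with
# "ValueError: No deployments available for selected model, ..." once
# self.deployment_stats records 2 requests for this deployment.
for _ in range(3):
    router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )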