From e49325b2348a9013ad4553131cd39cfb8aa36565 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 31 May 2024 21:32:01 -0700
Subject: [PATCH] fix(router.py): fix cooldown logic for usage-based-routing-v2
 pre-call-checks

---
 litellm/main.py                 |  5 +++++
 litellm/router.py               | 28 ++++++++++++++++++++++------
 litellm/tests/test_scheduler.py | 22 ++++++++++++++++++----
 3 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index 525a39d68..32cad89af 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -430,6 +430,10 @@ def mock_completion(
                 model=model,  # type: ignore
                 request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
             )
+        time_delay = kwargs.get("mock_delay", None)
+        if time_delay is not None:
+            time.sleep(time_delay)
+
         model_response = ModelResponse(stream=stream)
         if stream is True:
             # don't try to access stream object,
@@ -880,6 +884,7 @@ def completion(
                 mock_response=mock_response,
                 logging=logging,
                 acompletion=acompletion,
+                mock_delay=kwargs.get("mock_delay", None),
             )
         if custom_llm_provider == "azure":
             # azure configs
diff --git a/litellm/router.py b/litellm/router.py
index d7ff84877..c0abad448 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -631,7 +631,6 @@ class Router:
                 kwargs=kwargs,
                 client_type="max_parallel_requests",
             )
-
             if rpm_semaphore is not None and isinstance(
                 rpm_semaphore, asyncio.Semaphore
            ):
@@ -1875,6 +1874,7 @@ class Router:
                 error=e,
                 healthy_deployments=_healthy_deployments,
                 context_window_fallbacks=context_window_fallbacks,
+                regular_fallbacks=fallbacks,
             )
 
             # decides how long to sleep before retry
@@ -1884,7 +1884,6 @@ class Router:
                 num_retries=num_retries,
                 healthy_deployments=_healthy_deployments,
             )
-
             # sleeps for the length of the timeout
             await asyncio.sleep(_timeout)
 
@@ -1929,6 +1928,7 @@ class Router:
                         healthy_deployments=_healthy_deployments,
                     )
                     await asyncio.sleep(_timeout)
+
             try:
                 cooldown_deployments = await self._async_get_cooldown_deployments()
                 original_exception.message += f"\nNumber Retries = {current_attempt + 1}, Max Retries={num_retries}\nCooldown Deployments={cooldown_deployments}"
@@ -1941,6 +1941,7 @@ class Router:
         error: Exception,
         healthy_deployments: Optional[List] = None,
         context_window_fallbacks: Optional[List] = None,
+        regular_fallbacks: Optional[List] = None,
     ):
         """
         1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
@@ -1957,7 +1958,7 @@ class Router:
         ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
         if (
             isinstance(error, litellm.ContextWindowExceededError)
-            and context_window_fallbacks is None
+            and context_window_fallbacks is not None
         ):
             raise error
 
@@ -1965,7 +1966,11 @@ class Router:
         if isinstance(error, openai.RateLimitError) or isinstance(
             error, openai.AuthenticationError
         ):
-            if _num_healthy_deployments <= 0:
+            if (
+                _num_healthy_deployments <= 0
+                and regular_fallbacks is not None
+                and len(regular_fallbacks) > 0
+            ):
                 raise error
 
         return True
@@ -2140,6 +2145,7 @@ class Router:
                 error=e,
                 healthy_deployments=_healthy_deployments,
                 context_window_fallbacks=context_window_fallbacks,
+                regular_fallbacks=fallbacks,
             )
 
             # decides how long to sleep before retry
@@ -2348,7 +2354,7 @@ class Router:
            the exception is not one that should be immediately retried (e.g.
            401)
         """
-        args = locals()
+
         if deployment is None:
             return
 
@@ -2519,7 +2525,17 @@ class Router:
         """
         for _callback in litellm.callbacks:
             if isinstance(_callback, CustomLogger):
-                response = await _callback.async_pre_call_check(deployment)
+                try:
+                    response = await _callback.async_pre_call_check(deployment)
+                except litellm.RateLimitError as e:
+                    self._set_cooldown_deployments(
+                        exception_status=e.status_code,
+                        deployment=deployment["model_info"]["id"],
+                        time_to_cooldown=self.cooldown_time,
+                    )
+                    raise e
+                except Exception as e:
+                    raise e
 
     def set_client(self, model: dict):
         """
diff --git a/litellm/tests/test_scheduler.py b/litellm/tests/test_scheduler.py
index 2e48eab3c..bba06d587 100644
--- a/litellm/tests/test_scheduler.py
+++ b/litellm/tests/test_scheduler.py
@@ -77,11 +77,13 @@ async def test_scheduler_prioritized_requests(p0, p1):
     assert await scheduler.peek(id="10", model_name="gpt-3.5-turbo") == False
 
 
-@pytest.mark.parametrize("p0, p1", [(0, 0), (0, 1), (1, 0)])
+@pytest.mark.parametrize("p0, p1", [(0, 1)])  # (0, 0), (1, 0)
 @pytest.mark.asyncio
 async def test_scheduler_prioritized_requests_mock_response(p0, p1):
     """
     2 requests for same model group
+
+    if model is at rate limit, ensure the higher priority request gets done first
     """
     scheduler = Scheduler()
@@ -96,12 +98,19 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
                 },
             },
         ],
-        timeout=2,
+        timeout=10,
+        num_retries=3,
+        cooldown_time=5,
         routing_strategy="usage-based-routing-v2",
     )
 
     scheduler.update_variables(llm_router=router)
 
+    await router.acompletion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hey!"}],
+    )
+
     async def _make_prioritized_call(flow_item: FlowItem):
         ## POLL QUEUE
         default_timeout = router.timeout
@@ -118,6 +127,7 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
             make_request = await scheduler.poll(
                 id=flow_item.request_id, model_name=flow_item.model_name
             )
+            print(f"make_request={make_request}, priority={flow_item.priority}")
             if make_request:  ## IF TRUE -> MAKE REQUEST
                 break
             else:  ## ELSE -> loop till default_timeout
@@ -131,7 +141,8 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
                 messages=[{"role": "user", "content": "Hey!"}],
             )
         except Exception as e:
-            return flow_item.priority, flow_item.request_id, "Error occurred"
+            print("Received error - {}".format(str(e)))
+            return flow_item.priority, flow_item.request_id, time.time()
 
         return flow_item.priority, flow_item.request_id, time.time()
 
@@ -159,7 +170,10 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
         print(f"Received response: {result}")
 
     print(f"responses: {completed_responses}")
+
     assert (
         completed_responses[0][0] == 0
     )  # assert higher priority request got done first
-    assert isinstance(completed_responses[1][2], str)  # 2nd request errored out
+    assert (
+        completed_responses[0][2] < completed_responses[1][2]
+    )  # higher priority request tried first
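
A note on the litellm/main.py change: mock_completion now honors an optional mock_delay kwarg, so a mocked call can simulate a slow deployment in tests. Below is a minimal sketch of using it directly; it is not taken from the repository, and the delay value and mock text are illustrative assumptions.

```python
import time

import litellm

# mock_response short-circuits the real provider call; the new mock_delay kwarg
# (added in this patch) makes mock_completion sleep before building the response,
# which is handy for simulating a slow deployment in tests.
start = time.time()
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey!"}],
    mock_response="Hello world this is Macintosh!",
    mock_delay=1.5,  # illustrative value: sleep ~1.5s before returning the mock
)
print(f"elapsed: {time.time() - start:.1f}s")
print(response.choices[0].message.content)  # prints the mock_response text
```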
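
On the router.py change itself: with usage-based-routing-v2, a routing-strategy callback's async_pre_call_check can raise litellm.RateLimitError when the chosen deployment is already at its limit. Before this patch that error escaped the callback loop shown in the last router.py hunk without marking the deployment; now the deployment is placed in cooldown for cooldown_time seconds before the error is re-raised, so retries and queued requests treat it as unavailable. The sketch below is one way to exercise that path; it is not from the repository tests, and the rpm value, cooldown_time, mock_delay, and model name are illustrative assumptions.

```python
import asyncio

import litellm
from litellm import Router

# One mocked deployment with a deliberately tiny rpm limit so the
# usage-based-routing-v2 pre-call check trips after the first request.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "Hello world!",
                "rpm": 1,  # assumed value; kept low on purpose
            },
        },
    ],
    routing_strategy="usage-based-routing-v2",
    num_retries=3,
    cooldown_time=5,  # seconds a rate-limited deployment stays in cooldown
    timeout=10,
)


async def main():
    for i in range(3):
        try:
            response = await router.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": f"Hey! ({i})"}],
                mock_delay=0.5,  # new kwarg from this patch; illustrative value
            )
            print(f"request {i}: {response.choices[0].message.content}")
        except litellm.RateLimitError as e:
            # With the fix, the rate-limited deployment is put into cooldown
            # before the error surfaces here (once retries are exhausted).
            print(f"request {i}: rate limited, deployment cooled down ({e})")
        except Exception as e:
            # Depending on timing, the router may instead report that no
            # deployments are available while the cooldown is active.
            print(f"request {i}: {type(e).__name__}: {e}")


asyncio.run(main())
```

The reworked test in litellm/tests/test_scheduler.py leans on the same behavior: it fires one plain acompletion call up front so the mocked deployment is pushed toward its rate limit, then asserts that the higher-priority queued request is attempted first while the deployment cools down.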