fix(router.py): fix cooldown logic for usage-based-routing-v2 pre-call-checks

This commit is contained in:
Krrish Dholakia 2024-05-31 21:32:01 -07:00
parent f3c37f487a
commit e49325b234
3 changed files with 45 additions and 10 deletions

View file

@ -430,6 +430,10 @@ def mock_completion(
model=model, # type: ignore model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"), request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
) )
time_delay = kwargs.get("mock_delay", None)
if time_delay is not None:
time.sleep(time_delay)
model_response = ModelResponse(stream=stream) model_response = ModelResponse(stream=stream)
if stream is True: if stream is True:
# don't try to access stream object, # don't try to access stream object,
@ -880,6 +884,7 @@ def completion(
mock_response=mock_response, mock_response=mock_response,
logging=logging, logging=logging,
acompletion=acompletion, acompletion=acompletion,
mock_delay=kwargs.get("mock_delay", None),
) )
if custom_llm_provider == "azure": if custom_llm_provider == "azure":
# azure configs # azure configs

View file

@ -631,7 +631,6 @@ class Router:
kwargs=kwargs, kwargs=kwargs,
client_type="max_parallel_requests", client_type="max_parallel_requests",
) )
if rpm_semaphore is not None and isinstance( if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore rpm_semaphore, asyncio.Semaphore
): ):
@ -1875,6 +1874,7 @@ class Router:
error=e, error=e,
healthy_deployments=_healthy_deployments, healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks, context_window_fallbacks=context_window_fallbacks,
regular_fallbacks=fallbacks,
) )
# decides how long to sleep before retry # decides how long to sleep before retry
@ -1884,7 +1884,6 @@ class Router:
num_retries=num_retries, num_retries=num_retries,
healthy_deployments=_healthy_deployments, healthy_deployments=_healthy_deployments,
) )
# sleeps for the length of the timeout # sleeps for the length of the timeout
await asyncio.sleep(_timeout) await asyncio.sleep(_timeout)
@ -1929,6 +1928,7 @@ class Router:
healthy_deployments=_healthy_deployments, healthy_deployments=_healthy_deployments,
) )
await asyncio.sleep(_timeout) await asyncio.sleep(_timeout)
try: try:
cooldown_deployments = await self._async_get_cooldown_deployments() cooldown_deployments = await self._async_get_cooldown_deployments()
original_exception.message += f"\nNumber Retries = {current_attempt + 1}, Max Retries={num_retries}\nCooldown Deployments={cooldown_deployments}" original_exception.message += f"\nNumber Retries = {current_attempt + 1}, Max Retries={num_retries}\nCooldown Deployments={cooldown_deployments}"
@ -1941,6 +1941,7 @@ class Router:
error: Exception, error: Exception,
healthy_deployments: Optional[List] = None, healthy_deployments: Optional[List] = None,
context_window_fallbacks: Optional[List] = None, context_window_fallbacks: Optional[List] = None,
regular_fallbacks: Optional[List] = None,
): ):
""" """
1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None 1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
@ -1957,7 +1958,7 @@ class Router:
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
if ( if (
isinstance(error, litellm.ContextWindowExceededError) isinstance(error, litellm.ContextWindowExceededError)
and context_window_fallbacks is None and context_window_fallbacks is not None
): ):
raise error raise error
@ -1965,7 +1966,11 @@ class Router:
if isinstance(error, openai.RateLimitError) or isinstance( if isinstance(error, openai.RateLimitError) or isinstance(
error, openai.AuthenticationError error, openai.AuthenticationError
): ):
if _num_healthy_deployments <= 0: if (
_num_healthy_deployments <= 0
and regular_fallbacks is not None
and len(regular_fallbacks) > 0
):
raise error raise error
return True return True
@ -2140,6 +2145,7 @@ class Router:
error=e, error=e,
healthy_deployments=_healthy_deployments, healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks, context_window_fallbacks=context_window_fallbacks,
fallbacks=fallbacks,
) )
# decides how long to sleep before retry # decides how long to sleep before retry
@ -2348,7 +2354,7 @@ class Router:
the exception is not one that should be immediately retried (e.g. 401) the exception is not one that should be immediately retried (e.g. 401)
""" """
args = locals()
if deployment is None: if deployment is None:
return return
@ -2519,7 +2525,17 @@ class Router:
""" """
for _callback in litellm.callbacks: for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger): if isinstance(_callback, CustomLogger):
response = await _callback.async_pre_call_check(deployment) try:
response = await _callback.async_pre_call_check(deployment)
except litellm.RateLimitError as e:
self._set_cooldown_deployments(
exception_status=e.status_code,
deployment=deployment["model_info"]["id"],
time_to_cooldown=self.cooldown_time,
)
raise e
except Exception as e:
raise e
def set_client(self, model: dict): def set_client(self, model: dict):
""" """

View file

@ -77,11 +77,13 @@ async def test_scheduler_prioritized_requests(p0, p1):
assert await scheduler.peek(id="10", model_name="gpt-3.5-turbo") == False assert await scheduler.peek(id="10", model_name="gpt-3.5-turbo") == False
@pytest.mark.parametrize("p0, p1", [(0, 0), (0, 1), (1, 0)]) @pytest.mark.parametrize("p0, p1", [(0, 1)]) # (0, 0), (1, 0)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_scheduler_prioritized_requests_mock_response(p0, p1): async def test_scheduler_prioritized_requests_mock_response(p0, p1):
""" """
2 requests for same model group 2 requests for same model group
if model is at rate limit, ensure the higher priority request gets done first
""" """
scheduler = Scheduler() scheduler = Scheduler()
@ -96,12 +98,19 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
}, },
}, },
], ],
timeout=2, timeout=10,
num_retries=3,
cooldown_time=5,
routing_strategy="usage-based-routing-v2", routing_strategy="usage-based-routing-v2",
) )
scheduler.update_variables(llm_router=router) scheduler.update_variables(llm_router=router)
await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey!"}],
)
async def _make_prioritized_call(flow_item: FlowItem): async def _make_prioritized_call(flow_item: FlowItem):
## POLL QUEUE ## POLL QUEUE
default_timeout = router.timeout default_timeout = router.timeout
@ -118,6 +127,7 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
make_request = await scheduler.poll( make_request = await scheduler.poll(
id=flow_item.request_id, model_name=flow_item.model_name id=flow_item.request_id, model_name=flow_item.model_name
) )
print(f"make_request={make_request}, priority={flow_item.priority}")
if make_request: ## IF TRUE -> MAKE REQUEST if make_request: ## IF TRUE -> MAKE REQUEST
break break
else: ## ELSE -> loop till default_timeout else: ## ELSE -> loop till default_timeout
@ -131,7 +141,8 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
messages=[{"role": "user", "content": "Hey!"}], messages=[{"role": "user", "content": "Hey!"}],
) )
except Exception as e: except Exception as e:
return flow_item.priority, flow_item.request_id, "Error occurred" print("Received error - {}".format(str(e)))
return flow_item.priority, flow_item.request_id, time.time()
return flow_item.priority, flow_item.request_id, time.time() return flow_item.priority, flow_item.request_id, time.time()
@ -159,7 +170,10 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
print(f"Received response: {result}") print(f"Received response: {result}")
print(f"responses: {completed_responses}") print(f"responses: {completed_responses}")
assert ( assert (
completed_responses[0][0] == 0 completed_responses[0][0] == 0
) # assert higher priority request got done first ) # assert higher priority request got done first
assert isinstance(completed_responses[1][2], str) # 2nd request errored out assert (
completed_responses[0][2] < completed_responses[1][2]
) # higher priority request tried first