mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(router.py): fix cooldown logic for usage-based-routing-v2 pre-call-checks
This commit is contained in:
parent
f3c37f487a
commit
e49325b234
3 changed files with 45 additions and 10 deletions
|
@ -430,6 +430,10 @@ def mock_completion(
|
||||||
model=model, # type: ignore
|
model=model, # type: ignore
|
||||||
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
|
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
|
||||||
)
|
)
|
||||||
|
time_delay = kwargs.get("mock_delay", None)
|
||||||
|
if time_delay is not None:
|
||||||
|
time.sleep(time_delay)
|
||||||
|
|
||||||
model_response = ModelResponse(stream=stream)
|
model_response = ModelResponse(stream=stream)
|
||||||
if stream is True:
|
if stream is True:
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
|
@ -880,6 +884,7 @@ def completion(
|
||||||
mock_response=mock_response,
|
mock_response=mock_response,
|
||||||
logging=logging,
|
logging=logging,
|
||||||
acompletion=acompletion,
|
acompletion=acompletion,
|
||||||
|
mock_delay=kwargs.get("mock_delay", None),
|
||||||
)
|
)
|
||||||
if custom_llm_provider == "azure":
|
if custom_llm_provider == "azure":
|
||||||
# azure configs
|
# azure configs
|
||||||
|
|
|
@ -631,7 +631,6 @@ class Router:
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
client_type="max_parallel_requests",
|
client_type="max_parallel_requests",
|
||||||
)
|
)
|
||||||
|
|
||||||
if rpm_semaphore is not None and isinstance(
|
if rpm_semaphore is not None and isinstance(
|
||||||
rpm_semaphore, asyncio.Semaphore
|
rpm_semaphore, asyncio.Semaphore
|
||||||
):
|
):
|
||||||
|
@ -1875,6 +1874,7 @@ class Router:
|
||||||
error=e,
|
error=e,
|
||||||
healthy_deployments=_healthy_deployments,
|
healthy_deployments=_healthy_deployments,
|
||||||
context_window_fallbacks=context_window_fallbacks,
|
context_window_fallbacks=context_window_fallbacks,
|
||||||
|
regular_fallbacks=fallbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
# decides how long to sleep before retry
|
# decides how long to sleep before retry
|
||||||
|
@ -1884,7 +1884,6 @@ class Router:
|
||||||
num_retries=num_retries,
|
num_retries=num_retries,
|
||||||
healthy_deployments=_healthy_deployments,
|
healthy_deployments=_healthy_deployments,
|
||||||
)
|
)
|
||||||
|
|
||||||
# sleeps for the length of the timeout
|
# sleeps for the length of the timeout
|
||||||
await asyncio.sleep(_timeout)
|
await asyncio.sleep(_timeout)
|
||||||
|
|
||||||
|
@ -1929,6 +1928,7 @@ class Router:
|
||||||
healthy_deployments=_healthy_deployments,
|
healthy_deployments=_healthy_deployments,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(_timeout)
|
await asyncio.sleep(_timeout)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cooldown_deployments = await self._async_get_cooldown_deployments()
|
cooldown_deployments = await self._async_get_cooldown_deployments()
|
||||||
original_exception.message += f"\nNumber Retries = {current_attempt + 1}, Max Retries={num_retries}\nCooldown Deployments={cooldown_deployments}"
|
original_exception.message += f"\nNumber Retries = {current_attempt + 1}, Max Retries={num_retries}\nCooldown Deployments={cooldown_deployments}"
|
||||||
|
@ -1941,6 +1941,7 @@ class Router:
|
||||||
error: Exception,
|
error: Exception,
|
||||||
healthy_deployments: Optional[List] = None,
|
healthy_deployments: Optional[List] = None,
|
||||||
context_window_fallbacks: Optional[List] = None,
|
context_window_fallbacks: Optional[List] = None,
|
||||||
|
regular_fallbacks: Optional[List] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
|
1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
|
||||||
|
@ -1957,7 +1958,7 @@ class Router:
|
||||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
|
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
|
||||||
if (
|
if (
|
||||||
isinstance(error, litellm.ContextWindowExceededError)
|
isinstance(error, litellm.ContextWindowExceededError)
|
||||||
and context_window_fallbacks is None
|
and context_window_fallbacks is not None
|
||||||
):
|
):
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
|
@ -1965,7 +1966,11 @@ class Router:
|
||||||
if isinstance(error, openai.RateLimitError) or isinstance(
|
if isinstance(error, openai.RateLimitError) or isinstance(
|
||||||
error, openai.AuthenticationError
|
error, openai.AuthenticationError
|
||||||
):
|
):
|
||||||
if _num_healthy_deployments <= 0:
|
if (
|
||||||
|
_num_healthy_deployments <= 0
|
||||||
|
and regular_fallbacks is not None
|
||||||
|
and len(regular_fallbacks) > 0
|
||||||
|
):
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -2140,6 +2145,7 @@ class Router:
|
||||||
error=e,
|
error=e,
|
||||||
healthy_deployments=_healthy_deployments,
|
healthy_deployments=_healthy_deployments,
|
||||||
context_window_fallbacks=context_window_fallbacks,
|
context_window_fallbacks=context_window_fallbacks,
|
||||||
|
fallbacks=fallbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
# decides how long to sleep before retry
|
# decides how long to sleep before retry
|
||||||
|
@ -2348,7 +2354,7 @@ class Router:
|
||||||
|
|
||||||
the exception is not one that should be immediately retried (e.g. 401)
|
the exception is not one that should be immediately retried (e.g. 401)
|
||||||
"""
|
"""
|
||||||
args = locals()
|
|
||||||
if deployment is None:
|
if deployment is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -2519,7 +2525,17 @@ class Router:
|
||||||
"""
|
"""
|
||||||
for _callback in litellm.callbacks:
|
for _callback in litellm.callbacks:
|
||||||
if isinstance(_callback, CustomLogger):
|
if isinstance(_callback, CustomLogger):
|
||||||
response = await _callback.async_pre_call_check(deployment)
|
try:
|
||||||
|
response = await _callback.async_pre_call_check(deployment)
|
||||||
|
except litellm.RateLimitError as e:
|
||||||
|
self._set_cooldown_deployments(
|
||||||
|
exception_status=e.status_code,
|
||||||
|
deployment=deployment["model_info"]["id"],
|
||||||
|
time_to_cooldown=self.cooldown_time,
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
def set_client(self, model: dict):
|
def set_client(self, model: dict):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -77,11 +77,13 @@ async def test_scheduler_prioritized_requests(p0, p1):
|
||||||
assert await scheduler.peek(id="10", model_name="gpt-3.5-turbo") == False
|
assert await scheduler.peek(id="10", model_name="gpt-3.5-turbo") == False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("p0, p1", [(0, 0), (0, 1), (1, 0)])
|
@pytest.mark.parametrize("p0, p1", [(0, 1)]) # (0, 0), (1, 0)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_scheduler_prioritized_requests_mock_response(p0, p1):
|
async def test_scheduler_prioritized_requests_mock_response(p0, p1):
|
||||||
"""
|
"""
|
||||||
2 requests for same model group
|
2 requests for same model group
|
||||||
|
|
||||||
|
if model is at rate limit, ensure the higher priority request gets done first
|
||||||
"""
|
"""
|
||||||
scheduler = Scheduler()
|
scheduler = Scheduler()
|
||||||
|
|
||||||
|
@ -96,12 +98,19 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
timeout=2,
|
timeout=10,
|
||||||
|
num_retries=3,
|
||||||
|
cooldown_time=5,
|
||||||
routing_strategy="usage-based-routing-v2",
|
routing_strategy="usage-based-routing-v2",
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler.update_variables(llm_router=router)
|
scheduler.update_variables(llm_router=router)
|
||||||
|
|
||||||
|
await router.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey!"}],
|
||||||
|
)
|
||||||
|
|
||||||
async def _make_prioritized_call(flow_item: FlowItem):
|
async def _make_prioritized_call(flow_item: FlowItem):
|
||||||
## POLL QUEUE
|
## POLL QUEUE
|
||||||
default_timeout = router.timeout
|
default_timeout = router.timeout
|
||||||
|
@ -118,6 +127,7 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
|
||||||
make_request = await scheduler.poll(
|
make_request = await scheduler.poll(
|
||||||
id=flow_item.request_id, model_name=flow_item.model_name
|
id=flow_item.request_id, model_name=flow_item.model_name
|
||||||
)
|
)
|
||||||
|
print(f"make_request={make_request}, priority={flow_item.priority}")
|
||||||
if make_request: ## IF TRUE -> MAKE REQUEST
|
if make_request: ## IF TRUE -> MAKE REQUEST
|
||||||
break
|
break
|
||||||
else: ## ELSE -> loop till default_timeout
|
else: ## ELSE -> loop till default_timeout
|
||||||
|
@ -131,7 +141,8 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
|
||||||
messages=[{"role": "user", "content": "Hey!"}],
|
messages=[{"role": "user", "content": "Hey!"}],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return flow_item.priority, flow_item.request_id, "Error occurred"
|
print("Received error - {}".format(str(e)))
|
||||||
|
return flow_item.priority, flow_item.request_id, time.time()
|
||||||
|
|
||||||
return flow_item.priority, flow_item.request_id, time.time()
|
return flow_item.priority, flow_item.request_id, time.time()
|
||||||
|
|
||||||
|
@ -159,7 +170,10 @@ async def test_scheduler_prioritized_requests_mock_response(p0, p1):
|
||||||
print(f"Received response: {result}")
|
print(f"Received response: {result}")
|
||||||
|
|
||||||
print(f"responses: {completed_responses}")
|
print(f"responses: {completed_responses}")
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
completed_responses[0][0] == 0
|
completed_responses[0][0] == 0
|
||||||
) # assert higher priority request got done first
|
) # assert higher priority request got done first
|
||||||
assert isinstance(completed_responses[1][2], str) # 2nd request errored out
|
assert (
|
||||||
|
completed_responses[0][2] < completed_responses[1][2]
|
||||||
|
) # higher priority request tried first
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue