forked from phoenix/litellm-mirror
fix(router.py): make router async calls coroutine safe
Uses pre-call checks to verify that a call is below its rpm limit; this works even if multiple async calls are made simultaneously.
This commit is contained in:
parent
a101591f74
commit
0d1cca9aa0
2 changed files with 72 additions and 21 deletions
|
@ -522,9 +522,9 @@ class Router:
|
|||
messages=messages,
|
||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||
)
|
||||
if self.set_verbose == True and self.debug_level == "DEBUG":
|
||||
# debug how often this deployment picked
|
||||
self._print_deployment_metrics(deployment=deployment)
|
||||
|
||||
# debug how often this deployment picked
|
||||
self._print_deployment_metrics(deployment=deployment)
|
||||
|
||||
kwargs.setdefault("metadata", {}).update(
|
||||
{
|
||||
|
@ -582,9 +582,9 @@ class Router:
|
|||
verbose_router_logger.info(
|
||||
f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
|
||||
)
|
||||
if self.set_verbose == True and self.debug_level == "DEBUG":
|
||||
# debug how often this deployment picked
|
||||
self._print_deployment_metrics(deployment=deployment, response=response)
|
||||
# debug how often this deployment picked
|
||||
self._print_deployment_metrics(deployment=deployment, response=response)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
verbose_router_logger.info(
|
||||
|
@ -2360,6 +2360,8 @@ class Router:
|
|||
except Exception as e:
|
||||
return _returned_deployments
|
||||
|
||||
_context_window_error = False
|
||||
_rate_limit_error = False
|
||||
for idx, deployment in enumerate(_returned_deployments):
|
||||
# see if we have the info for this model
|
||||
try:
|
||||
|
@ -2384,19 +2386,48 @@ class Router:
|
|||
and input_tokens > model_info["max_input_tokens"]
|
||||
):
|
||||
invalid_model_indices.append(idx)
|
||||
_context_window_error = True
|
||||
continue
|
||||
|
||||
## TPM/RPM CHECK ##
|
||||
_litellm_params = deployment.get("litellm_params", {})
|
||||
_model_id = deployment.get("model_info", {}).get("id", "")
|
||||
|
||||
if (
|
||||
isinstance(_litellm_params, dict)
|
||||
and _litellm_params.get("rpm", None) is not None
|
||||
):
|
||||
if (
|
||||
isinstance(_litellm_params["rpm"], int)
|
||||
and _model_id in self.deployment_stats
|
||||
and _litellm_params["rpm"]
|
||||
<= self.deployment_stats[_model_id]["num_requests"]
|
||||
):
|
||||
invalid_model_indices.append(idx)
|
||||
_rate_limit_error = True
|
||||
continue
|
||||
|
||||
if len(invalid_model_indices) == len(_returned_deployments):
|
||||
"""
|
||||
- no healthy deployments available b/c context window checks
|
||||
- no healthy deployments available b/c context window checks or rate limit error
|
||||
|
||||
- First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
|
||||
"""
|
||||
raise litellm.ContextWindowExceededError(
|
||||
message="Context Window exceeded for given call",
|
||||
model=model,
|
||||
llm_provider="",
|
||||
response=httpx.Response(
|
||||
status_code=400, request=httpx.Request("GET", "https://example.com")
|
||||
),
|
||||
)
|
||||
|
||||
if _rate_limit_error == True: # allow generic fallback logic to take place
|
||||
raise ValueError(
|
||||
f"No deployments available for selected model, passed model={model}"
|
||||
)
|
||||
elif _context_window_error == True:
|
||||
raise litellm.ContextWindowExceededError(
|
||||
message="Context Window exceeded for given call",
|
||||
model=model,
|
||||
llm_provider="",
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
request=httpx.Request("GET", "https://example.com"),
|
||||
),
|
||||
)
|
||||
if len(invalid_model_indices) > 0:
|
||||
for idx in reversed(invalid_model_indices):
|
||||
_returned_deployments.pop(idx)
|
||||
|
@ -2606,13 +2637,16 @@ class Router:
|
|||
"num_successes": 1,
|
||||
"avg_latency": response_ms,
|
||||
}
|
||||
from pprint import pformat
|
||||
if self.set_verbose == True and self.debug_level == "DEBUG":
|
||||
from pprint import pformat
|
||||
|
||||
# Assuming self.deployment_stats is your dictionary
|
||||
formatted_stats = pformat(self.deployment_stats)
|
||||
# Assuming self.deployment_stats is your dictionary
|
||||
formatted_stats = pformat(self.deployment_stats)
|
||||
|
||||
# Assuming verbose_router_logger is your logger
|
||||
verbose_router_logger.info("self.deployment_stats: \n%s", formatted_stats)
|
||||
# Assuming verbose_router_logger is your logger
|
||||
verbose_router_logger.info(
|
||||
"self.deployment_stats: \n%s", formatted_stats
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(f"Error in _print_deployment_metrics: {str(e)}")
|
||||
|
||||
|
|
|
@ -4,6 +4,14 @@ import os
|
|||
import pytest
|
||||
import random
|
||||
from typing import Any
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from pydantic import BaseModel
|
||||
from litellm import utils, Router
|
||||
|
@ -35,6 +43,7 @@ def router_factory():
|
|||
return Router(
|
||||
model_list=model_list,
|
||||
routing_strategy=routing_strategy,
|
||||
enable_pre_call_checks=True,
|
||||
debug_level="DEBUG",
|
||||
)
|
||||
|
||||
|
@ -115,15 +124,23 @@ def test_rate_limit(
|
|||
ExpectNoException: Signfies that no other error has happened. A NOP
|
||||
"""
|
||||
# Can send more messages then we're going to; so don't expect a rate limit error
|
||||
args = locals()
|
||||
print(f"args: {args}")
|
||||
expected_exception = (
|
||||
ExpectNoException if num_try_send <= num_allowed_send else ValueError
|
||||
)
|
||||
|
||||
# if (
|
||||
# num_try_send > num_allowed_send and sync_mode == False
|
||||
# ): # async calls are made simultaneously - the check for collision would need to happen before the router call
|
||||
# return
|
||||
|
||||
list_of_messages = generate_list_of_messages(max(num_try_send, num_allowed_send))
|
||||
rpm, tpm = calculate_limits(list_of_messages[:num_allowed_send])
|
||||
list_of_messages = list_of_messages[:num_try_send]
|
||||
router = router_factory(rpm, tpm, routing_strategy)
|
||||
router: Router = router_factory(rpm, tpm, routing_strategy)
|
||||
|
||||
print(f"router: {router.model_list}")
|
||||
with pytest.raises(expected_exception) as excinfo: # asserts correct type raised
|
||||
if sync_mode:
|
||||
results = sync_call(router, list_of_messages)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue