forked from phoenix/litellm-mirror
fix(router.py): make router async calls coroutine safe
uses pre-call checks to check if a call is below it's rpm limit, works even if multiple async calls are made simultaneously
This commit is contained in:
parent
a101591f74
commit
0d1cca9aa0
2 changed files with 72 additions and 21 deletions
|
@ -522,9 +522,9 @@ class Router:
|
|||
messages=messages,
|
||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||
)
|
||||
if self.set_verbose == True and self.debug_level == "DEBUG":
|
||||
# debug how often this deployment picked
|
||||
self._print_deployment_metrics(deployment=deployment)
|
||||
|
||||
# debug how often this deployment picked
|
||||
self._print_deployment_metrics(deployment=deployment)
|
||||
|
||||
kwargs.setdefault("metadata", {}).update(
|
||||
{
|
||||
|
@ -582,9 +582,9 @@ class Router:
|
|||
verbose_router_logger.info(
|
||||
f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
|
||||
)
|
||||
if self.set_verbose == True and self.debug_level == "DEBUG":
|
||||
# debug how often this deployment picked
|
||||
self._print_deployment_metrics(deployment=deployment, response=response)
|
||||
# debug how often this deployment picked
|
||||
self._print_deployment_metrics(deployment=deployment, response=response)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
verbose_router_logger.info(
|
||||
|
@ -2360,6 +2360,8 @@ class Router:
|
|||
except Exception as e:
|
||||
return _returned_deployments
|
||||
|
||||
_context_window_error = False
|
||||
_rate_limit_error = False
|
||||
for idx, deployment in enumerate(_returned_deployments):
|
||||
# see if we have the info for this model
|
||||
try:
|
||||
|
@ -2384,19 +2386,48 @@ class Router:
|
|||
and input_tokens > model_info["max_input_tokens"]
|
||||
):
|
||||
invalid_model_indices.append(idx)
|
||||
_context_window_error = True
|
||||
continue
|
||||
|
||||
## TPM/RPM CHECK ##
|
||||
_litellm_params = deployment.get("litellm_params", {})
|
||||
_model_id = deployment.get("model_info", {}).get("id", "")
|
||||
|
||||
if (
|
||||
isinstance(_litellm_params, dict)
|
||||
and _litellm_params.get("rpm", None) is not None
|
||||
):
|
||||
if (
|
||||
isinstance(_litellm_params["rpm"], int)
|
||||
and _model_id in self.deployment_stats
|
||||
and _litellm_params["rpm"]
|
||||
<= self.deployment_stats[_model_id]["num_requests"]
|
||||
):
|
||||
invalid_model_indices.append(idx)
|
||||
_rate_limit_error = True
|
||||
continue
|
||||
|
||||
if len(invalid_model_indices) == len(_returned_deployments):
|
||||
"""
|
||||
- no healthy deployments available b/c context window checks
|
||||
- no healthy deployments available b/c context window checks or rate limit error
|
||||
|
||||
- First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
|
||||
"""
|
||||
raise litellm.ContextWindowExceededError(
|
||||
message="Context Window exceeded for given call",
|
||||
model=model,
|
||||
llm_provider="",
|
||||
response=httpx.Response(
|
||||
status_code=400, request=httpx.Request("GET", "https://example.com")
|
||||
),
|
||||
)
|
||||
|
||||
if _rate_limit_error == True: # allow generic fallback logic to take place
|
||||
raise ValueError(
|
||||
f"No deployments available for selected model, passed model={model}"
|
||||
)
|
||||
elif _context_window_error == True:
|
||||
raise litellm.ContextWindowExceededError(
|
||||
message="Context Window exceeded for given call",
|
||||
model=model,
|
||||
llm_provider="",
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
request=httpx.Request("GET", "https://example.com"),
|
||||
),
|
||||
)
|
||||
if len(invalid_model_indices) > 0:
|
||||
for idx in reversed(invalid_model_indices):
|
||||
_returned_deployments.pop(idx)
|
||||
|
@ -2606,13 +2637,16 @@ class Router:
|
|||
"num_successes": 1,
|
||||
"avg_latency": response_ms,
|
||||
}
|
||||
from pprint import pformat
|
||||
if self.set_verbose == True and self.debug_level == "DEBUG":
|
||||
from pprint import pformat
|
||||
|
||||
# Assuming self.deployment_stats is your dictionary
|
||||
formatted_stats = pformat(self.deployment_stats)
|
||||
# Assuming self.deployment_stats is your dictionary
|
||||
formatted_stats = pformat(self.deployment_stats)
|
||||
|
||||
# Assuming verbose_router_logger is your logger
|
||||
verbose_router_logger.info("self.deployment_stats: \n%s", formatted_stats)
|
||||
# Assuming verbose_router_logger is your logger
|
||||
verbose_router_logger.info(
|
||||
"self.deployment_stats: \n%s", formatted_stats
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(f"Error in _print_deployment_metrics: {str(e)}")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue