fix(router.py): make router async calls coroutine safe

Uses pre-call checks to verify that a deployment is below its RPM limit before routing to it;
this works even if multiple async calls are made simultaneously.
Krrish Dholakia 2024-04-06 17:31:07 -07:00
parent a101591f74
commit 0d1cca9aa0
2 changed files with 72 additions and 21 deletions
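
For context, here is a minimal sketch of the pre-call RPM check this commit adds. It is a standalone rendition of the logic in the diff below, not the full Router; the simplified shape of `deployment_stats` is an assumption.

# Minimal sketch, assuming a simplified deployment_stats shape: a deployment
# is filtered out before the call if its configured RPM has already been
# reached, so concurrent coroutines never pick an exhausted deployment.
deployment_stats = {"model-id-1": {"num_requests": 3}}

def is_below_rpm_limit(litellm_params: dict, model_id: str) -> bool:
    rpm = litellm_params.get("rpm")
    if not isinstance(rpm, int):
        return True  # no integer RPM configured -> nothing to enforce
    if model_id not in deployment_stats:
        return True  # no traffic recorded yet for this deployment
    return deployment_stats[model_id]["num_requests"] < rpm

print(is_below_rpm_limit({"rpm": 3}, "model-id-1"))  # False: limit reached
print(is_below_rpm_limit({"rpm": 5}, "model-id-1"))  # True: still below limit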

@@ -522,9 +522,9 @@ class Router:
                 messages=messages,
                 specific_deployment=kwargs.pop("specific_deployment", None),
             )
-            if self.set_verbose == True and self.debug_level == "DEBUG":
-                # debug how often this deployment picked
-                self._print_deployment_metrics(deployment=deployment)
+
+            # debug how often this deployment picked
+            self._print_deployment_metrics(deployment=deployment)
             kwargs.setdefault("metadata", {}).update(
                 {
@@ -582,9 +582,9 @@ class Router:
             verbose_router_logger.info(
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
             )
-            if self.set_verbose == True and self.debug_level == "DEBUG":
-                # debug how often this deployment picked
-                self._print_deployment_metrics(deployment=deployment, response=response)
+
+            # debug how often this deployment picked
+            self._print_deployment_metrics(deployment=deployment, response=response)
             return response
         except Exception as e:
             verbose_router_logger.info(
@@ -2360,6 +2360,8 @@ class Router:
             except Exception as e:
                 return _returned_deployments

+        _context_window_error = False
+        _rate_limit_error = False
         for idx, deployment in enumerate(_returned_deployments):
             # see if we have the info for this model
             try:
@@ -2384,19 +2386,48 @@ class Router:
                     and input_tokens > model_info["max_input_tokens"]
                 ):
                     invalid_model_indices.append(idx)
+                    _context_window_error = True
                     continue
+
+            ## TPM/RPM CHECK ##
+            _litellm_params = deployment.get("litellm_params", {})
+            _model_id = deployment.get("model_info", {}).get("id", "")
+
+            if (
+                isinstance(_litellm_params, dict)
+                and _litellm_params.get("rpm", None) is not None
+            ):
+                if (
+                    isinstance(_litellm_params["rpm"], int)
+                    and _model_id in self.deployment_stats
+                    and _litellm_params["rpm"]
+                    <= self.deployment_stats[_model_id]["num_requests"]
+                ):
+                    invalid_model_indices.append(idx)
+                    _rate_limit_error = True
+                    continue
+
         if len(invalid_model_indices) == len(_returned_deployments):
             """
-            - no healthy deployments available b/c context window checks
+            - no healthy deployments available b/c context window checks or rate limit error
+            - First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
             """
-            raise litellm.ContextWindowExceededError(
-                message="Context Window exceeded for given call",
-                model=model,
-                llm_provider="",
-                response=httpx.Response(
-                    status_code=400, request=httpx.Request("GET", "https://example.com")
-                ),
-            )
+
+            if _rate_limit_error == True:  # allow generic fallback logic to take place
+                raise ValueError(
+                    f"No deployments available for selected model, passed model={model}"
+                )
+            elif _context_window_error == True:
+                raise litellm.ContextWindowExceededError(
+                    message="Context Window exceeded for given call",
+                    model=model,
+                    llm_provider="",
+                    response=httpx.Response(
+                        status_code=400,
+                        request=httpx.Request("GET", "https://example.com"),
+                    ),
+                )
+
         if len(invalid_model_indices) > 0:
             for idx in reversed(invalid_model_indices):
                 _returned_deployments.pop(idx)
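
The ordering in the hunk above matters. Here is a hedged sketch of why the rate-limit case is checked first, standalone and with a plain exception class standing in for litellm's.

# Sketch of the precedence, simplified from the diff above (assumption:
# a plain exception stands in for litellm.ContextWindowExceededError).
class ContextWindowExceededError(Exception):
    pass

def raise_for_unhealthy_deployments(rate_limit_error: bool, context_window_error: bool):
    # Rate limits win: the deployment fit the context window but is at its
    # RPM cap, so a generic ValueError lets the router's ordinary fallback
    # logic retry on another model.
    if rate_limit_error:
        raise ValueError("No deployments available for selected model")
    # Only when every deployment failed the token check is the error truly
    # about context length.
    if context_window_error:
        raise ContextWindowExceededError("Context Window exceeded for given call")

try:
    raise_for_unhealthy_deployments(rate_limit_error=True, context_window_error=True)
except ValueError as e:
    print(e)  # rate limit takes precedence even when both flags are set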
@@ -2606,13 +2637,16 @@ class Router:
                 "num_successes": 1,
                 "avg_latency": response_ms,
             }
-            from pprint import pformat
+            if self.set_verbose == True and self.debug_level == "DEBUG":
+                from pprint import pformat

-            # Assuming self.deployment_stats is your dictionary
-            formatted_stats = pformat(self.deployment_stats)
+                # Assuming self.deployment_stats is your dictionary
+                formatted_stats = pformat(self.deployment_stats)

-            # Assuming verbose_router_logger is your logger
-            verbose_router_logger.info("self.deployment_stats: \n%s", formatted_stats)
+                # Assuming verbose_router_logger is your logger
+                verbose_router_logger.info(
+                    "self.deployment_stats: \n%s", formatted_stats
+                )
         except Exception as e:
             verbose_router_logger.error(f"Error in _print_deployment_metrics: {str(e)}")
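
Finally, a hedged usage sketch of the coroutine-safety claim in the commit message. The model name, rpm value, and credentials are illustrative assumptions, not taken from the diff.

import asyncio
from litellm import Router

# Illustrative config (assumption): any provider credentials must be set
# in the environment for the calls to actually succeed.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 2},
        }
    ]
)

async def main():
    # Fire several calls simultaneously; the pre-call check keeps the
    # deployment from being selected again once its rpm limit is reached.
    tasks = [
        router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hello"}],
        )
        for _ in range(5)
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for r in results:
        print(type(r).__name__)  # responses, plus errors for over-limit calls

asyncio.run(main())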