fix(router.py): make router async calls coroutine safe

Uses pre-call checks to verify that a call is below its RPM limit; this works even if multiple async calls are made simultaneously.
parent a101591f74
commit 0d1cca9aa0

2 changed files with 72 additions and 21 deletions
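To illustrate the idea in the commit description, here is a minimal, self-contained sketch (the RPM_LIMIT constant, the shared counter, and the call_with_pre_call_check function are hypothetical stand-ins, not the router's internals): on a single asyncio event loop, coroutines only interleave at await points, so a synchronous check-and-increment of a per-deployment request counter, done before dispatching the call, cannot be raced by other concurrent async calls.

import asyncio

RPM_LIMIT = 2     # hypothetical per-deployment requests-per-minute limit
num_requests = 0  # hypothetical shared counter, analogous to per-deployment stats

async def call_with_pre_call_check(i: int) -> str:
    global num_requests
    # The check and the increment happen with no await in between, so two
    # concurrent coroutines cannot both pass the check once the limit is hit.
    if num_requests >= RPM_LIMIT:
        raise ValueError(f"request {i}: no deployments available, rpm limit reached")
    num_requests += 1
    await asyncio.sleep(0.01)  # stand-in for the actual LLM call
    return f"request {i}: ok"

async def main() -> None:
    results = await asyncio.gather(
        *(call_with_pre_call_check(i) for i in range(4)),
        return_exceptions=True,
    )
    for r in results:
        print(r)  # two succeed, two raise ValueError

asyncio.run(main())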
router.py

@@ -522,7 +522,7 @@ class Router:
                 messages=messages,
                 specific_deployment=kwargs.pop("specific_deployment", None),
             )
-            # debug how often this deployment picked
-            self._print_deployment_metrics(deployment=deployment)
+            if self.set_verbose == True and self.debug_level == "DEBUG":
+                # debug how often this deployment picked
+                self._print_deployment_metrics(deployment=deployment)
 
@@ -582,9 +582,9 @@ class Router:
             verbose_router_logger.info(
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
             )
-            # debug how often this deployment picked
-            self._print_deployment_metrics(deployment=deployment, response=response)
+            if self.set_verbose == True and self.debug_level == "DEBUG":
+                # debug how often this deployment picked
+                self._print_deployment_metrics(deployment=deployment, response=response)
 
             return response
         except Exception as e:
             verbose_router_logger.info(
@@ -2360,6 +2360,8 @@ class Router:
         except Exception as e:
             return _returned_deployments
 
+        _context_window_error = False
+        _rate_limit_error = False
         for idx, deployment in enumerate(_returned_deployments):
             # see if we have the info for this model
             try:
@@ -2384,17 +2386,46 @@ class Router:
                     and input_tokens > model_info["max_input_tokens"]
                 ):
                     invalid_model_indices.append(idx)
+                    _context_window_error = True
+                    continue
+
+                ## TPM/RPM CHECK ##
+                _litellm_params = deployment.get("litellm_params", {})
+                _model_id = deployment.get("model_info", {}).get("id", "")
+
+                if (
+                    isinstance(_litellm_params, dict)
+                    and _litellm_params.get("rpm", None) is not None
+                ):
+                    if (
+                        isinstance(_litellm_params["rpm"], int)
+                        and _model_id in self.deployment_stats
+                        and _litellm_params["rpm"]
+                        <= self.deployment_stats[_model_id]["num_requests"]
+                    ):
+                        invalid_model_indices.append(idx)
+                        _rate_limit_error = True
+                        continue
 
         if len(invalid_model_indices) == len(_returned_deployments):
             """
-            - no healthy deployments available b/c context window checks
+            - no healthy deployments available b/c context window checks or rate limit error
+
+            - First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
             """
-            raise litellm.ContextWindowExceededError(
-                message="Context Window exceeded for given call",
-                model=model,
-                llm_provider="",
-                response=httpx.Response(
-                    status_code=400, request=httpx.Request("GET", "https://example.com")
-                ),
-            )
+
+            if _rate_limit_error == True:  # allow generic fallback logic to take place
+                raise ValueError(
+                    f"No deployments available for selected model, passed model={model}"
+                )
+            elif _context_window_error == True:
+                raise litellm.ContextWindowExceededError(
+                    message="Context Window exceeded for given call",
+                    model=model,
+                    llm_provider="",
+                    response=httpx.Response(
+                        status_code=400,
+                        request=httpx.Request("GET", "https://example.com"),
+                    ),
+                )
         if len(invalid_model_indices) > 0:
@@ -2606,13 +2637,16 @@ class Router:
                     "num_successes": 1,
                     "avg_latency": response_ms,
                 }
 
-            from pprint import pformat
+            if self.set_verbose == True and self.debug_level == "DEBUG":
+                from pprint import pformat
 
-            # Assuming self.deployment_stats is your dictionary
-            formatted_stats = pformat(self.deployment_stats)
+                # Assuming self.deployment_stats is your dictionary
+                formatted_stats = pformat(self.deployment_stats)
 
-            # Assuming verbose_router_logger is your logger
-            verbose_router_logger.info("self.deployment_stats: \n%s", formatted_stats)
+                # Assuming verbose_router_logger is your logger
+                verbose_router_logger.info(
+                    "self.deployment_stats: \n%s", formatted_stats
+                )
         except Exception as e:
             verbose_router_logger.error(f"Error in _print_deployment_metrics: {str(e)}")
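For context, a usage sketch of the configuration this change targets, assuming a single deployment with rpm=1 and a valid API key (the model name, the OPENAI_API_KEY environment variable, and the messages are placeholders; enable_pre_call_checks=True mirrors the flag turned on in the second changed file, the rate-limit test whose diff follows):

import asyncio
import os

from litellm import Router

# hypothetical model_list: one deployment capped at 1 request per minute
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
            "rpm": 1,
        },
        "model_info": {"id": "deployment-1"},
    }
]

router = Router(
    model_list=model_list,
    enable_pre_call_checks=True,  # turn on the pre-call RPM/TPM filtering
    set_verbose=True,
    debug_level="DEBUG",
)

async def fire_two_simultaneous_calls() -> None:
    # With rpm=1, calls beyond the limit are expected to fail the pre-call
    # check (the behavior this commit targets) instead of slipping past it.
    results = await asyncio.gather(
        router.acompletion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}]),
        router.acompletion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi again"}]),
        return_exceptions=True,
    )
    print(results)

if __name__ == "__main__":
    asyncio.run(fire_two_simultaneous_calls())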
@@ -4,6 +4,14 @@ import os
 import pytest
 import random
 from typing import Any
+import sys
+from dotenv import load_dotenv
+
+load_dotenv()
+
+sys.path.insert(
+    0, os.path.abspath("../")
+)  # Adds the parent directory to the system path
 
 from pydantic import BaseModel
 from litellm import utils, Router
@@ -35,6 +43,7 @@ def router_factory():
         return Router(
             model_list=model_list,
             routing_strategy=routing_strategy,
+            enable_pre_call_checks=True,
             debug_level="DEBUG",
         )
 
|
@ -115,15 +124,23 @@ def test_rate_limit(
|
||||||
ExpectNoException: Signfies that no other error has happened. A NOP
|
ExpectNoException: Signfies that no other error has happened. A NOP
|
||||||
"""
|
"""
|
||||||
# Can send more messages then we're going to; so don't expect a rate limit error
|
# Can send more messages then we're going to; so don't expect a rate limit error
|
||||||
|
args = locals()
|
||||||
|
print(f"args: {args}")
|
||||||
expected_exception = (
|
expected_exception = (
|
||||||
ExpectNoException if num_try_send <= num_allowed_send else ValueError
|
ExpectNoException if num_try_send <= num_allowed_send else ValueError
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if (
|
||||||
|
# num_try_send > num_allowed_send and sync_mode == False
|
||||||
|
# ): # async calls are made simultaneously - the check for collision would need to happen before the router call
|
||||||
|
# return
|
||||||
|
|
||||||
list_of_messages = generate_list_of_messages(max(num_try_send, num_allowed_send))
|
list_of_messages = generate_list_of_messages(max(num_try_send, num_allowed_send))
|
||||||
rpm, tpm = calculate_limits(list_of_messages[:num_allowed_send])
|
rpm, tpm = calculate_limits(list_of_messages[:num_allowed_send])
|
||||||
list_of_messages = list_of_messages[:num_try_send]
|
list_of_messages = list_of_messages[:num_try_send]
|
||||||
router = router_factory(rpm, tpm, routing_strategy)
|
router: Router = router_factory(rpm, tpm, routing_strategy)
|
||||||
|
|
||||||
|
print(f"router: {router.model_list}")
|
||||||
with pytest.raises(expected_exception) as excinfo: # asserts correct type raised
|
with pytest.raises(expected_exception) as excinfo: # asserts correct type raised
|
||||||
if sync_mode:
|
if sync_mode:
|
||||||
results = sync_call(router, list_of_messages)
|
results = sync_call(router, list_of_messages)
|
||||||
|
|
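The test hunk is truncated after the sync branch. Given the commit message's emphasis on simultaneous async calls, the async branch presumably fires the requests concurrently; a sketch of how such a helper might look (the async_call name, the else branch it would live in, and the placeholder model name are assumptions, not taken from the truncated diff):

import asyncio
from typing import Any, Dict, List

from litellm import Router

def async_call(router: Router, list_of_messages: List[List[Dict[str, Any]]]) -> List[Any]:
    """Fire one acompletion per message list simultaneously and collect the results."""

    async def _run() -> List[Any]:
        tasks = [
            # placeholder model name; the real test would use the router's configured model
            router.acompletion(model="gpt-3.5-turbo", messages=messages)
            for messages in list_of_messages
        ]
        # no return_exceptions: the first rate-limit ValueError propagates to pytest.raises
        return await asyncio.gather(*tasks)

    return asyncio.run(_run())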