fix(test_lowest_tpm_rpm_routing_v2.py): unit testing for usage-based-routing-v2

This commit is contained in:
Krrish Dholakia 2024-04-18 21:38:00 -07:00
parent 72691e05f4
commit 376ee4e9d7
6 changed files with 171 additions and 53 deletions

View file

@ -31,6 +31,7 @@ import copy
from litellm._logging import verbose_router_logger
import logging
from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
from litellm.integrations.custom_logger import CustomLogger
class Router:
@ -492,18 +493,18 @@ class Router:
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if (
rpm_semaphore is not None
and isinstance(rpm_semaphore, asyncio.Semaphore)
and self.routing_strategy == "usage-based-routing-v2"
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
await self.routing_strategy_pre_call_checks(deployment=deployment)
response = await _response
else:
await self.routing_strategy_pre_call_checks(deployment=deployment)
response = await _response
self.success_calls[model_name] += 1
@ -1712,6 +1713,22 @@ class Router:
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
async def routing_strategy_pre_call_checks(self, deployment: dict) -> None:
    """
    Run every registered routing-strategy pre-call check for a deployment.

    For usage-based-routing-v2, this enables running rpm checks before the
    call is made. When the caller holds the deployment's rpm semaphore while
    awaiting this method, the check-and-increment is concurrency-safe for
    deployments that have rpm limits set.

    Each entry in `litellm.callbacks` that is a `CustomLogger` instance has
    its `async_pre_call_check(deployment)` awaited, one at a time, in list
    order. Entries of any other type are ignored.

    Args:
        deployment: the deployment dict that is about to be called.

    Returns:
        - None

    Raises:
        - Whatever a callback's `async_pre_call_check` raises; it is not
          caught here. Intended to be a Rate Limit Exception when the
          deployment is over its tpm/rpm limits.
    """
    for _callback in litellm.callbacks:
        if isinstance(_callback, CustomLogger):
            # The check signals failure by raising; its return value is not
            # used, so it is deliberately not bound to a local.
            await _callback.async_pre_call_check(deployment)
def set_client(self, model: dict):
"""
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
@ -2700,6 +2717,7 @@ class Router:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
)
return deployment
def get_available_deployment(