forked from phoenix/litellm-mirror
fix(test_lowest_tpm_rpm_routing_v2.py): unit testing for usage-based-routing-v2
parent a05f148c17 · commit 81573b2dd9
6 changed files with 171 additions and 53 deletions
@@ -21,7 +21,9 @@ class ServiceLogging(CustomLogger):
         if "prometheus_system" in litellm.service_callback:
             self.prometheusServicesLogger = PrometheusServicesLogger()

-    def service_success_hook(self, service: ServiceTypes, duration: float):
+    def service_success_hook(
+        self, service: ServiceTypes, duration: float, call_type: str
+    ):
         """
         [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
         """
@@ -29,7 +31,7 @@ class ServiceLogging(CustomLogger):
         self.mock_testing_sync_success_hook += 1

     def service_failure_hook(
-        self, service: ServiceTypes, duration: float, error: Exception
+        self, service: ServiceTypes, duration: float, error: Exception, call_type: str
     ):
         """
         [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
@@ -217,6 +217,7 @@ class RedisCache(BaseCache):
                 self.service_logger_obj.service_success_hook(
                     service=ServiceTypes.REDIS,
                     duration=_duration,
+                    call_type="increment_cache",
                 )
             )
             return result
@@ -226,11 +227,14 @@ class RedisCache(BaseCache):
             _duration = end_time - start_time
             asyncio.create_task(
                 self.service_logger_obj.async_service_failure_hook(
-                    service=ServiceTypes.REDIS, duration=_duration, error=e
+                    service=ServiceTypes.REDIS,
+                    duration=_duration,
+                    error=e,
+                    call_type="increment_cache",
                 )
             )
             verbose_logger.error(
-                "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
+                "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
                 str(e),
                 value,
             )
@@ -278,6 +282,9 @@ class RedisCache(BaseCache):

     async def async_set_cache(self, key, value, **kwargs):
+        start_time = time.time()
         print_verbose(
             f"Set Async Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
         )
+        try:
             _redis_client = self.init_async_client()
+        except Exception as e:
@@ -341,6 +348,10 @@ class RedisCache(BaseCache):
         """
         _redis_client = self.init_async_client()
+        start_time = time.time()
+
         print_verbose(
             f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
         )
+        try:
             async with _redis_client as redis_client:
                 async with redis_client.pipeline(transaction=True) as pipe:
@@ -1261,7 +1272,6 @@ class DualCache(BaseCache):
             print_verbose(f"in_memory_result: {in_memory_result}")
             if in_memory_result is not None:
                 result = in_memory_result
-
             if None in result and self.redis_cache is not None and local_only == False:
                 """
                 - for the none values in the result
@@ -1277,14 +1287,12 @@ class DualCache(BaseCache):

             if redis_result is not None:
                 # Update in-memory cache with the value from Redis
-                for key in redis_result:
-                    await self.in_memory_cache.async_set_cache(
-                        key, redis_result[key], **kwargs
-                    )
-
-            sublist_dict = dict(zip(sublist_keys, redis_result))
-
-            for key, value in sublist_dict.items():
+                for key, value in redis_result.items():
+                    if value is not None:
+                        await self.in_memory_cache.async_set_cache(
+                            key, redis_result[key], **kwargs
+                        )
+            for key, value in redis_result.items():
                 result[sublist_keys.index(key)] = value

             print_verbose(f"async batch get cache: cache result: {result}")
@@ -1293,6 +1301,9 @@ class DualCache(BaseCache):
             traceback.print_exc()

     async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
+        print_verbose(
+            f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
+        )
         try:
             if self.in_memory_cache is not None:
                 await self.in_memory_cache.async_set_cache(key, value, **kwargs)
@@ -6,7 +6,7 @@ import requests
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache

-from typing import Literal, Union
+from typing import Literal, Union, Optional

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -46,6 +46,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         pass

+    #### PRE-CALL CHECKS - router/proxy only ####
+    """
+    Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
+    """
+
+    async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
+        pass
+
+    def pre_call_check(self, deployment: dict) -> Optional[dict]:
+        pass
+
     #### CALL HOOKS - proxy only ####
     """
     Control the modify incoming / outgoung data before calling the model
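For context, a minimal sketch of a callback that implements the new async_pre_call_check hook is below. The class name, the in-process counter, and the ValueError are illustrative only (the handler shipped in this commit raises litellm.RateLimitError and tracks usage in the router cache), so treat it as a sketch rather than the commit's implementation.

from typing import Optional

from litellm.integrations.custom_logger import CustomLogger


class PreCallRPMGuard(CustomLogger):
    # Illustrative: keeps a per-deployment counter in process memory.
    def __init__(self, max_calls: int = 60):
        self.max_calls = max_calls
        self.counts: dict = {}

    async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
        model_id = deployment.get("model_info", {}).get("id")
        current = self.counts.get(model_id, 0)
        if current >= self.max_calls:
            # The real v2 handler raises litellm.RateLimitError here.
            raise ValueError(f"deployment {model_id} over the illustrative limit")
        self.counts[model_id] = current + 1
        return deployment

A callback like this would be registered via litellm.callbacks, which is the list the router iterates in routing_strategy_pre_call_checks (see the Router hunks below).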
@@ -31,6 +31,7 @@ import copy
 from litellm._logging import verbose_router_logger
 import logging
 from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
+from litellm.integrations.custom_logger import CustomLogger


 class Router:
@@ -492,18 +493,18 @@ class Router:
                 deployment=deployment, kwargs=kwargs, client_type="rpm_client"
             )

-            if (
-                rpm_semaphore is not None
-                and isinstance(rpm_semaphore, asyncio.Semaphore)
-                and self.routing_strategy == "usage-based-routing-v2"
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
             ):
                 async with rpm_semaphore:
                     """
                     - Check rpm limits before making the call
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
-                    await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
+                    await self.routing_strategy_pre_call_checks(deployment=deployment)
                     response = await _response
             else:
+                await self.routing_strategy_pre_call_checks(deployment=deployment)
                 response = await _response
+
             self.success_calls[model_name] += 1
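The reason the check runs inside the semaphore is so that the read-then-increment on the rpm counter cannot interleave across concurrent requests to the same deployment. A stripped-down illustration of that pattern (plain asyncio, no litellm types; the limit and counter are invented):

import asyncio

LIMIT = 3  # invented per-minute limit
counter = {"requests_this_minute": 0}


async def call_deployment(i: int, rpm_semaphore: asyncio.Semaphore) -> None:
    async with rpm_semaphore:
        # check + increment happen atomically with respect to other callers
        if counter["requests_this_minute"] >= LIMIT:
            print(f"request {i}: rejected, over limit")
            return
        counter["requests_this_minute"] += 1
    print(f"request {i}: allowed")


async def main() -> None:
    rpm_semaphore = asyncio.Semaphore(1)  # the router keeps one per deployment
    await asyncio.gather(*(call_deployment(i, rpm_semaphore) for i in range(5)))


asyncio.run(main())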
@@ -1712,6 +1713,22 @@ class Router:
             verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models

+    async def routing_strategy_pre_call_checks(self, deployment: dict):
+        """
+        For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
+
+        -> makes the calls concurrency-safe, when rpm limits are set for a deployment
+
+        Returns:
+        - None
+
+        Raises:
+        - Rate Limit Exception - If the deployment is over it's tpm/rpm limits
+        """
+        for _callback in litellm.callbacks:
+            if isinstance(_callback, CustomLogger):
+                response = await _callback.async_pre_call_check(deployment)
+
     def set_client(self, model: dict):
         """
        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
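As a usage sketch (not part of this diff), a router configured with the v2 strategy and a per-deployment rpm would exercise this path. The model names, id, and rpm value below are placeholders, and real provider credentials are assumed:

import asyncio
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "azure-model",  # placeholder model group
            "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1440},
            "model_info": {"id": "1"},
        }
    ],
    routing_strategy="usage-based-routing-v2",
)


async def main():
    # acompletion acquires the deployment's rpm semaphore, runs
    # routing_strategy_pre_call_checks(), then awaits the underlying call.
    response = await router.acompletion(
        model="azure-model",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response)


asyncio.run(main())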
@@ -2700,6 +2717,7 @@ class Router:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
             )
+
         return deployment

     def get_available_deployment(
@@ -39,7 +39,81 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         self.router_cache = router_cache
         self.model_list = model_list

-    async def pre_call_rpm_check(self, deployment: dict) -> dict:
+    def pre_call_check(self, deployment: Dict) -> Dict | None:
+        """
+        Pre-call check + update model rpm
+
+        Returns - deployment
+
+        Raises - RateLimitError if deployment over defined RPM limit
+        """
+        try:
+
+            # ------------
+            # Setup values
+            # ------------
+            dt = get_utc_datetime()
+            current_minute = dt.strftime("%H-%M")
+            model_id = deployment.get("model_info", {}).get("id")
+            rpm_key = f"{model_id}:rpm:{current_minute}"
+            local_result = self.router_cache.get_cache(
+                key=rpm_key, local_only=True
+            )  # check local result first
+
+            deployment_rpm = None
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("model_info", {}).get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = float("inf")
+
+            if local_result is not None and local_result >= deployment_rpm:
+                raise litellm.RateLimitError(
+                    message="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, local_result
+                    ),
+                    llm_provider="",
+                    model=deployment.get("litellm_params", {}).get("model"),
+                    response=httpx.Response(
+                        status_code=429,
+                        content="{} rpm limit={}. current usage={}".format(
+                            RouterErrors.user_defined_ratelimit_error.value,
+                            deployment_rpm,
+                            local_result,
+                        ),
+                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                    ),
+                )
+            else:
+                # if local result below limit, check redis ## prevent unnecessary redis checks
+                result = self.router_cache.increment_cache(key=rpm_key, value=1)
+                if result is not None and result > deployment_rpm:
+                    raise litellm.RateLimitError(
+                        message="Deployment over defined rpm limit={}. current usage={}".format(
+                            deployment_rpm, result
+                        ),
+                        llm_provider="",
+                        model=deployment.get("litellm_params", {}).get("model"),
+                        response=httpx.Response(
+                            status_code=429,
+                            content="{} rpm limit={}. current usage={}".format(
+                                RouterErrors.user_defined_ratelimit_error.value,
+                                deployment_rpm,
+                                result,
+                            ),
+                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                        ),
+                    )
+            return deployment
+        except Exception as e:
+            if isinstance(e, litellm.RateLimitError):
+                raise e
+            return deployment  # don't fail calls if eg. redis fails to connect
+
+    async def async_pre_call_check(self, deployment: Dict) -> Dict | None:
         """
         Pre-call check + update model rpm
         - Used inside semaphore
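The rpm accounting above hinges on a per-deployment, per-minute cache key. A standalone sketch of that scheme, using the DualCache and get_utc_datetime helpers shown elsewhere in this diff, is below; the deployment dict and the limit of 10 are made-up values:

from litellm.caching import DualCache
from litellm.utils import get_utc_datetime

cache = DualCache()  # in-memory only here; the router may also attach Redis

# Made-up deployment; id and rpm are placeholders.
deployment = {"model_info": {"id": "1"}, "litellm_params": {"rpm": 10}}

dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
model_id = deployment["model_info"]["id"]
rpm_key = f"{model_id}:rpm:{current_minute}"  # one bucket per deployment per minute

usage = cache.get_cache(key=rpm_key, local_only=True) or 0
if usage >= deployment["litellm_params"]["rpm"]:
    print("deployment over its rpm limit for this minute")
else:
    cache.increment_cache(key=rpm_key, value=1)  # concurrency-safe when backed by Redis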
@@ -58,8 +132,8 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         # ------------
         dt = get_utc_datetime()
         current_minute = dt.strftime("%H-%M")
-        model_group = deployment.get("model_name", "")
-        rpm_key = f"{model_group}:rpm:{current_minute}"
+        model_id = deployment.get("model_info", {}).get("id")
+        rpm_key = f"{model_id}:rpm:{current_minute}"
         local_result = await self.router_cache.async_get_cache(
             key=rpm_key, local_only=True
         )  # check local result first
@@ -246,21 +320,26 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             for deployment in healthy_deployments:
                 tpm_dict[deployment["model_info"]["id"]] = 0
         else:
+            dt = get_utc_datetime()
+            current_minute = dt.strftime(
+                "%H-%M"
+            )  # use the same timezone regardless of system clock
+
             for d in healthy_deployments:
                 ## if healthy deployment not yet used
-                if d["model_info"]["id"] not in tpm_dict:
-                    tpm_dict[d["model_info"]["id"]] = 0
+                tpm_key = f"{d['model_info']['id']}:tpm:{current_minute}"
+                if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
+                    tpm_dict[tpm_key] = 0

         all_deployments = tpm_dict

         deployment = None
         for item, item_tpm in all_deployments.items():
             ## get the item from model list
             _deployment = None
+            item = item.split(":")[0]
             for m in healthy_deployments:
                 if item == m["model_info"]["id"]:
                     _deployment = m

             if _deployment is None:
                 continue  # skip to next one
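For clarity on the selection loop above: the usage dict is now keyed by "<deployment id>:tpm:<minute>", so the id has to be split back out before matching against healthy_deployments. A toy illustration (ids and token counts are invented):

# Invented data to illustrate the key shape used by get_available_deployments.
tpm_dict = {"1:tpm:17-05": 200, "2:tpm:17-05": 0}
healthy_deployments = [
    {"model_info": {"id": "1"}},
    {"model_info": {"id": "2"}},
]

for item, item_tpm in tpm_dict.items():
    deployment_id = item.split(":")[0]  # strip the ":tpm:<minute>" suffix
    match = next(
        (m for m in healthy_deployments if m["model_info"]["id"] == deployment_id),
        None,
    )
    print(deployment_id, item_tpm, match is not None)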
@@ -283,7 +362,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 _deployment_rpm = _deployment.get("model_info", {}).get("rpm")
             if _deployment_rpm is None:
                 _deployment_rpm = float("inf")
-
             if item_tpm + input_tokens > _deployment_tpm:
                 continue
             elif (rpm_dict is not None and item in rpm_dict) and (
@@ -1,5 +1,5 @@
 #### What this tests ####
-# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2'
+# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2-v2'

 import sys, os, asyncio, time, random
 from datetime import datetime
@@ -18,6 +18,7 @@ import litellm
 from litellm.router_strategy.lowest_tpm_rpm_v2 import (
     LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler,
 )
+from litellm.utils import get_utc_datetime
 from litellm.caching import DualCache

 ### UNIT TESTS FOR TPM/RPM ROUTING ###
@@ -43,20 +44,23 @@ def test_tpm_rpm_updated():
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 50}}
     end_time = time.time()
+    lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
     lowest_tpm_logger.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
         start_time=start_time,
         end_time=end_time,
     )
-    current_minute = datetime.now().strftime("%H-%M")
-    tpm_count_api_key = f"{model_group}:tpm:{current_minute}"
-    rpm_count_api_key = f"{model_group}:rpm:{current_minute}"
-    assert (
-        response_obj["usage"]["total_tokens"]
-        == test_cache.get_cache(key=tpm_count_api_key)[deployment_id]
+    dt = get_utc_datetime()
+    current_minute = dt.strftime("%H-%M")
+    tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}"
+    rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}"
+
+    print(f"tpm_count_api_key={tpm_count_api_key}")
+    assert response_obj["usage"]["total_tokens"] == test_cache.get_cache(
+        key=tpm_count_api_key
     )
-    assert 1 == test_cache.get_cache(key=rpm_count_api_key)[deployment_id]
+    assert 1 == test_cache.get_cache(key=rpm_count_api_key)


 # test_tpm_rpm_updated()
@@ -122,13 +126,6 @@ def test_get_available_deployments():
     )

     ## CHECK WHAT'S SELECTED ##
-    print(
-        lowest_tpm_logger.get_available_deployments(
-            model_group=model_group,
-            healthy_deployments=model_list,
-            input=["Hello world"],
-        )
-    )
     assert (
         lowest_tpm_logger.get_available_deployments(
             model_group=model_group,
@@ -170,7 +167,7 @@ def test_router_get_available_deployments():
     ]
     router = Router(
         model_list=model_list,
-        routing_strategy="usage-based-routing",
+        routing_strategy="usage-based-routing-v2",
         set_verbose=False,
         num_retries=3,
     )  # type: ignore
@@ -189,7 +186,7 @@ def test_router_get_available_deployments():
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 50}}
     end_time = time.time()
-    router.lowesttpm_logger.log_success_event(
+    router.lowesttpm_logger_v2.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
         start_time=start_time,
@@ -208,7 +205,7 @@ def test_router_get_available_deployments():
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 20}}
     end_time = time.time()
-    router.lowesttpm_logger.log_success_event(
+    router.lowesttpm_logger_v2.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
         start_time=start_time,
@@ -216,7 +213,7 @@ def test_router_get_available_deployments():
     )

     ## CHECK WHAT'S SELECTED ##
-    # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
+    # print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model"))
     assert (
         router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2"
     )
@@ -244,7 +241,7 @@ def test_router_skip_rate_limited_deployments():
     ]
     router = Router(
         model_list=model_list,
-        routing_strategy="usage-based-routing",
+        routing_strategy="usage-based-routing-v2",
         set_verbose=False,
         num_retries=3,
     )  # type: ignore
@@ -262,7 +259,7 @@ def test_router_skip_rate_limited_deployments():
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 1439}}
     end_time = time.time()
-    router.lowesttpm_logger.log_success_event(
+    router.lowesttpm_logger_v2.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
         start_time=start_time,
@@ -270,7 +267,7 @@ def test_router_skip_rate_limited_deployments():
     )

     ## CHECK WHAT'S SELECTED ##
-    # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
+    # print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model"))
     try:
         router.get_available_deployment(
             model="azure-model",
@@ -299,7 +296,7 @@ def test_single_deployment_tpm_zero():

     router = litellm.Router(
         model_list=model_list,
-        routing_strategy="usage-based-routing",
+        routing_strategy="usage-based-routing-v2",
         cache_responses=True,
     )

@@ -345,7 +342,7 @@ async def test_router_completion_streaming():
     ]
     router = Router(
         model_list=model_list,
-        routing_strategy="usage-based-routing",
+        routing_strategy="usage-based-routing-v2",
         set_verbose=False,
     )  # type: ignore

@@ -362,8 +359,9 @@ async def test_router_completion_streaming():
     if response is not None:
         ## CALL 3
         await asyncio.sleep(1)  # let the token update happen
-        current_minute = datetime.now().strftime("%H-%M")
-        picked_deployment = router.lowesttpm_logger.get_available_deployments(
+        dt = get_utc_datetime()
+        current_minute = dt.strftime("%H-%M")
+        picked_deployment = router.lowesttpm_logger_v2.get_available_deployments(
             model_group=model,
             healthy_deployments=router.healthy_deployments,
             messages=messages,