Merge pull request #3153 from BerriAI/litellm_usage_based_routing_v2_improvements

usage based routing v2 improvements - unit testing + *NEW* async + sync 'pre_call_checks'
Krish Dholakia 2024-04-18 22:16:16 -07:00 committed by GitHub
commit f1340b52dc
7 changed files with 723 additions and 43 deletions
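For context, this PR adds two new hooks to CustomLogger — a sync pre_call_check and an async async_pre_call_check — which the Router now invokes (via routing_strategy_pre_call_checks / async_routing_strategy_pre_call_checks) for every CustomLogger registered in litellm.callbacks. Below is a minimal sketch of implementing them, assuming the hook names and call flow shown in the diff; the subclass name and the registration line are illustrative, not part of the PR.

from typing import Optional
import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyPreCallChecker(CustomLogger):  # hypothetical subclass, for illustration only
    def pre_call_check(self, deployment: dict) -> Optional[dict]:
        # sync path: called by the Router right before litellm.completion() / embedding(), etc.
        # raising here (e.g. a rate-limit error) blocks the call for this deployment
        return deployment

    async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
        # async path: called inside the deployment's rpm semaphore, so any tpm/rpm
        # bookkeeping done here is concurrency-safe
        return deployment

litellm.callbacks = [MyPreCallChecker()]  # assumed registration; the Router iterates
                                          # litellm.callbacks and calls these hooks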

View file

@@ -21,7 +21,9 @@ class ServiceLogging(CustomLogger):
if "prometheus_system" in litellm.service_callback:
self.prometheusServicesLogger = PrometheusServicesLogger()
- def service_success_hook(self, service: ServiceTypes, duration: float):
+ def service_success_hook(
+     self, service: ServiceTypes, duration: float, call_type: str
+ ):
"""
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
"""
@@ -29,7 +31,7 @@ class ServiceLogging(CustomLogger):
self.mock_testing_sync_success_hook += 1
def service_failure_hook(
- self, service: ServiceTypes, duration: float, error: Exception
+ self, service: ServiceTypes, duration: float, error: Exception, call_type: str
):
"""
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).

View file

@@ -89,6 +89,13 @@ class InMemoryCache(BaseCache):
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
@@ -198,6 +205,42 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
)
def increment_cache(self, key, value: int, **kwargs) -> int:
_redis_client = self.redis_client
start_time = time.time()
try:
result = _redis_client.incr(name=key, amount=value)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="increment_cache",
)
)
return result
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="increment_cache",
)
)
verbose_logger.error(
"LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
raise e
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
start_time = time.time()
try:
@@ -302,6 +345,10 @@ class RedisCache(BaseCache):
"""
_redis_client = self.init_async_client()
start_time = time.time()
print_verbose(
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
)
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
@@ -1093,6 +1140,30 @@ class DualCache(BaseCache):
except Exception as e:
print_verbose(e)
def increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only == False:
result = self.redis_cache.increment_cache(key, value, **kwargs)
return result
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
raise e
def get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first
try:
@@ -1198,7 +1269,6 @@ class DualCache(BaseCache):
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
@@ -1214,14 +1284,12 @@ class DualCache(BaseCache):
if redis_result is not None:
# Update in-memory cache with the value from Redis
- for key in redis_result:
-     await self.in_memory_cache.async_set_cache(
-         key, redis_result[key], **kwargs
-     )
- sublist_dict = dict(zip(sublist_keys, redis_result))
- for key, value in sublist_dict.items():
+ for key, value in redis_result.items():
+     if value is not None:
+         await self.in_memory_cache.async_set_cache(
+             key, redis_result[key], **kwargs
+         )
+ for key, value in redis_result.items():
result[sublist_keys.index(key)] = value
print_verbose(f"async batch get cache: cache result: {result}")
@@ -1230,6 +1298,9 @@ class DualCache(BaseCache):
traceback.print_exc()
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
print_verbose(
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
)
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
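For context, a minimal usage sketch of the increment_cache API added above. The key name is a made-up example; DualCache() with no arguments uses only the in-memory layer (pass a RedisCache to also increment the counter in Redis).

from litellm.caching import DualCache

cache = DualCache()  # in-memory only in this sketch; no Redis configured
# increment the counter by 1; a missing key is treated as 0, so the first call returns 1
assert cache.increment_cache(key="model-1234:rpm:14-05", value=1) == 1
assert cache.increment_cache(key="model-1234:rpm:14-05", value=1) == 2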

View file

@@ -6,7 +6,7 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
- from typing import Literal, Union
+ from typing import Literal, Union, Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
@@ -46,6 +46,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
#### PRE-CALL CHECKS - router/proxy only ####
"""
Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
"""
async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
pass
def pre_call_check(self, deployment: dict) -> Optional[dict]:
pass
#### CALL HOOKS - proxy only ####
"""
Control the modify incoming / outgoung data before calling the model

View file

@@ -31,6 +31,7 @@ import copy
from litellm._logging import verbose_router_logger
import logging
from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
from litellm.integrations.custom_logger import CustomLogger
class Router:
@@ -379,6 +380,9 @@ class Router:
else:
model_client = potential_model_client
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.completion(
**{
**data,
@@ -391,6 +395,7 @@ class Router:
verbose_router_logger.info(
f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
@@ -494,18 +499,20 @@ class Router:
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
- if (
-     rpm_semaphore is not None
-     and isinstance(rpm_semaphore, asyncio.Semaphore)
-     and self.routing_strategy == "usage-based-routing-v2"
- ):
+ if rpm_semaphore is not None and isinstance(
+     rpm_semaphore, asyncio.Semaphore
+ ):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
+ - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
- await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
+ await self.async_routing_strategy_pre_call_checks(
+     deployment=deployment
+ )
response = await _response
else:
+ await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await _response
self.success_calls[model_name] += 1
@@ -580,6 +587,10 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.image_generation(
**{
**data,
@@ -658,7 +669,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
- response = await litellm.aimage_generation(
+ response = litellm.aimage_generation(
**{
**data,
"prompt": prompt,
@@ -667,6 +678,28 @@ class Router:
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m"
@@ -758,7 +791,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
- response = await litellm.atranscription(
+ response = litellm.atranscription(
**{
**data,
"file": file,
@@ -767,6 +800,28 @@ class Router:
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
@@ -981,7 +1036,8 @@ class Router:
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
- response = await litellm.atext_completion(
+ response = litellm.atext_completion(
**{
**data,
"prompt": prompt,
@@ -991,6 +1047,27 @@ class Router:
**kwargs,
}
)
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m"
@@ -1065,6 +1142,10 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.embedding(
**{
**data,
@@ -1150,7 +1231,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
- response = await litellm.aembedding(
+ response = litellm.aembedding(
**{
**data,
"input": input,
@@ -1159,6 +1240,28 @@ class Router:
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m"
@@ -1716,6 +1819,38 @@ class Router:
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def routing_strategy_pre_call_checks(self, deployment: dict):
"""
Mimics 'async_routing_strategy_pre_call_checks'
Ensures consistent update rpm implementation for 'usage-based-routing-v2'
Returns:
- None
Raises:
- Rate Limit Exception - If the deployment is over it's tpm/rpm limits
"""
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
response = _callback.pre_call_check(deployment)
async def async_routing_strategy_pre_call_checks(self, deployment: dict):
"""
For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
-> makes the calls concurrency-safe, when rpm limits are set for a deployment
Returns:
- None
Raises:
- Rate Limit Exception - If the deployment is over it's tpm/rpm limits
"""
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
response = await _callback.async_pre_call_check(deployment)
def set_client(self, model: dict):
"""
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
@@ -2704,6 +2839,7 @@ class Router:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
)
return deployment
def get_available_deployment(
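For context, the concurrency-safe rpm-check pattern repeated across the async endpoints above condenses to the sketch below. The names self, deployment, kwargs, and _response are assumed to exist exactly as in the diff; the wrapper function itself is illustrative, not part of the PR.

import asyncio

async def _await_with_rpm_check(self, deployment: dict, kwargs: dict, _response):
    # look up the deployment-specific semaphore created when rpm limits are set
    rpm_semaphore = self._get_client(
        deployment=deployment, kwargs=kwargs, client_type="rpm_client"
    )
    if rpm_semaphore is not None and isinstance(rpm_semaphore, asyncio.Semaphore):
        async with rpm_semaphore:
            # check + increment rpm while holding the semaphore, so the shared counter
            # cannot be raced by other requests to the same deployment
            await self.async_routing_strategy_pre_call_checks(deployment=deployment)
            response = await _response
    else:
        await self.async_routing_strategy_pre_call_checks(deployment=deployment)
        response = await _response
    return response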

View file

@@ -39,7 +39,81 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
self.router_cache = router_cache
self.model_list = model_list
- async def pre_call_rpm_check(self, deployment: dict) -> dict:
+ def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
"""
Pre-call check + update model rpm
Returns - deployment
Raises - RateLimitError if deployment over defined RPM limit
"""
try:
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
model_id = deployment.get("model_info", {}).get("id")
rpm_key = f"{model_id}:rpm:{current_minute}"
local_result = self.router_cache.get_cache(
key=rpm_key, local_only=True
) # check local result first
deployment_rpm = None
if deployment_rpm is None:
deployment_rpm = deployment.get("rpm")
if deployment_rpm is None:
deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = deployment.get("model_info", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = float("inf")
if local_result is not None and local_result >= deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, local_result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
local_result,
),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
else:
# if local result below limit, check redis ## prevent unnecessary redis checks
result = self.router_cache.increment_cache(key=rpm_key, value=1)
if result is not None and result > deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
result,
),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return deployment
except Exception as e:
if isinstance(e, litellm.RateLimitError):
raise e
return deployment # don't fail calls if eg. redis fails to connect
async def async_pre_call_check(self, deployment: Dict) -> Optional[Dict]:
""" """
Pre-call check + update model rpm Pre-call check + update model rpm
- Used inside semaphore - Used inside semaphore
@@ -58,8 +132,8 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
- model_group = deployment.get("model_name", "")
- rpm_key = f"{model_group}:rpm:{current_minute}"
+ model_id = deployment.get("model_info", {}).get("id")
+ rpm_key = f"{model_id}:rpm:{current_minute}"
local_result = await self.router_cache.async_get_cache(
key=rpm_key, local_only=True
) # check local result first
@@ -143,26 +217,18 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
# Setup values
# ------------
dt = get_utc_datetime()
- current_minute = dt.strftime("%H-%M")
- tpm_key = f"{model_group}:tpm:{current_minute}"
- rpm_key = f"{model_group}:rpm:{current_minute}"
+ current_minute = dt.strftime(
+     "%H-%M"
+ ) # use the same timezone regardless of system clock
+ tpm_key = f"{id}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
- # update cache
## TPM
- request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
- request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
- self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
- ## RPM
- request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
- request_count_dict[id] = request_count_dict.get(id, 0) + 1
- self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+ self.router_cache.increment_cache(key=tpm_key, value=total_tokens)
### TESTING ###
if self.test_flag:
self.logged_success += 1
@@ -254,21 +320,26 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
for deployment in healthy_deployments:
tpm_dict[deployment["model_info"]["id"]] = 0
else:
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
for d in healthy_deployments:
## if healthy deployment not yet used
- if d["model_info"]["id"] not in tpm_dict:
-     tpm_dict[d["model_info"]["id"]] = 0
+ tpm_key = f"{d['model_info']['id']}:tpm:{current_minute}"
+ if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
+     tpm_dict[tpm_key] = 0
all_deployments = tpm_dict
deployment = None
for item, item_tpm in all_deployments.items():
## get the item from model list
_deployment = None
item = item.split(":")[0]
for m in healthy_deployments:
if item == m["model_info"]["id"]:
_deployment = m
if _deployment is None:
continue # skip to next one
@@ -291,7 +362,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
_deployment_rpm = _deployment.get("model_info", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = float("inf")
if item_tpm + input_tokens > _deployment_tpm:
continue
elif (rpm_dict is not None and item in rpm_dict) and (

View file

@@ -0,0 +1,390 @@
#### What this tests ####
# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2-v2'
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm import Router
import litellm
from litellm.router_strategy.lowest_tpm_rpm_v2 import (
LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler,
)
from litellm.utils import get_utc_datetime
from litellm.caching import DualCache
### UNIT TESTS FOR TPM/RPM ROUTING ###
def test_tpm_rpm_updated():
test_cache = DualCache()
model_list = []
lowest_tpm_logger = LowestTPMLoggingHandler(
router_cache=test_cache, model_list=model_list
)
model_group = "gpt-3.5-turbo"
deployment_id = "1234"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": deployment_id},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
end_time = time.time()
lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
lowest_tpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}"
rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}"
print(f"tpm_count_api_key={tpm_count_api_key}")
assert response_obj["usage"]["total_tokens"] == test_cache.get_cache(
key=tpm_count_api_key
)
assert 1 == test_cache.get_cache(key=rpm_count_api_key)
# test_tpm_rpm_updated()
def test_get_available_deployments():
test_cache = DualCache()
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "azure/chatgpt-v-2"},
"model_info": {"id": "1234"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "azure/chatgpt-v-2"},
"model_info": {"id": "5678"},
},
]
lowest_tpm_logger = LowestTPMLoggingHandler(
router_cache=test_cache, model_list=model_list
)
model_group = "gpt-3.5-turbo"
## DEPLOYMENT 1 ##
deployment_id = "1234"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": deployment_id},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
end_time = time.time()
lowest_tpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## DEPLOYMENT 2 ##
deployment_id = "5678"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": deployment_id},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 20}}
end_time = time.time()
lowest_tpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## CHECK WHAT'S SELECTED ##
assert (
lowest_tpm_logger.get_available_deployments(
model_group=model_group,
healthy_deployments=model_list,
input=["Hello world"],
)["model_info"]["id"]
== "5678"
)
# test_get_available_deployments()
def test_router_get_available_deployments():
"""
Test if routers 'get_available_deployments' returns the least busy deployment
"""
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing-v2",
set_verbose=False,
num_retries=3,
) # type: ignore
print(f"router id's: {router.get_model_ids()}")
## DEPLOYMENT 1 ##
deployment_id = 1
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "azure-model",
},
"model_info": {"id": 1},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
end_time = time.time()
router.lowesttpm_logger_v2.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## DEPLOYMENT 2 ##
deployment_id = 2
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "azure-model",
},
"model_info": {"id": 2},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 20}}
end_time = time.time()
router.lowesttpm_logger_v2.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## CHECK WHAT'S SELECTED ##
# print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model"))
assert (
router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2"
)
# test_get_available_deployments()
# test_router_get_available_deployments()
def test_router_skip_rate_limited_deployments():
"""
Test if routers 'get_available_deployments' raises No Models Available error if max tpm would be reached by message
"""
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"tpm": 1440,
},
"model_info": {"id": 1},
},
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing-v2",
set_verbose=False,
num_retries=3,
) # type: ignore
## DEPLOYMENT 1 ##
deployment_id = 1
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "azure-model",
},
"model_info": {"id": deployment_id},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 1439}}
end_time = time.time()
router.lowesttpm_logger_v2.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## CHECK WHAT'S SELECTED ##
# print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model"))
try:
router.get_available_deployment(
model="azure-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
pytest.fail(f"Should have raised No Models Available error")
except Exception as e:
print(f"An exception occurred! {str(e)}")
def test_single_deployment_tpm_zero():
import litellm
import os
from datetime import datetime
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
"tpm": 0,
},
}
]
router = litellm.Router(
model_list=model_list,
routing_strategy="usage-based-routing-v2",
cache_responses=True,
)
model = "gpt-3.5-turbo"
messages = [{"content": "Hello, how are you?", "role": "user"}]
try:
router.get_available_deployment(
model=model,
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
pytest.fail(f"Should have raised No Models Available error")
except Exception as e:
print(f"it worked - {str(e)}! \n{traceback.format_exc()}")
@pytest.mark.asyncio
async def test_router_completion_streaming():
messages = [
{"role": "user", "content": "Hello, can you generate a 500 words poem?"}
]
model = "azure-model"
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing-v2",
set_verbose=False,
) # type: ignore
### Make 3 calls, test if 3rd call goes to lowest tpm deployment
## CALL 1+2
tasks = []
response = None
final_response = None
for _ in range(2):
tasks.append(router.acompletion(model=model, messages=messages))
response = await asyncio.gather(*tasks)
if response is not None:
## CALL 3
await asyncio.sleep(1) # let the token update happen
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
picked_deployment = router.lowesttpm_logger_v2.get_available_deployments(
model_group=model,
healthy_deployments=router.healthy_deployments,
messages=messages,
)
final_response = await router.acompletion(model=model, messages=messages)
print(f"min deployment id: {picked_deployment}")
tpm_key = f"{model}:tpm:{current_minute}"
rpm_key = f"{model}:rpm:{current_minute}"
tpm_dict = router.cache.get_cache(key=tpm_key)
print(f"tpm_dict: {tpm_dict}")
rpm_dict = router.cache.get_cache(key=rpm_key)
print(f"rpm_dict: {rpm_dict}")
print(f"model id: {final_response._hidden_params['model_id']}")
assert (
final_response._hidden_params["model_id"]
== picked_deployment["model_info"]["id"]
)
# asyncio.run(test_router_completion_streaming())
"""
- Unit test for sync 'pre_call_checks'
- Unit test for async 'async_pre_call_checks'
"""