forked from phoenix/litellm-mirror
feat(lowest_tpm_rpm_v2.py): move to using redis.incr and redis.mget for getting model usage from redis
makes routing work across multiple instances
This commit is contained in:
parent b2741933dc
commit 180cf9bd5c

5 changed files with 437 additions and 12 deletions
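Why this makes multi-instance routing work: Redis executes INCR atomically on the server, so concurrent writers cannot lose updates, and MGET reads many counters in one round trip. A minimal sketch with the redis-py client (connection details and key names here are illustrative, not from this commit):

import redis

r = redis.Redis(host="localhost", port=6379)  # placeholder connection

# Two router instances bumping the same counter cannot clobber each other:
# the addition happens inside the Redis server, not in client code.
r.incr("deployment-1:tpm:14-05", 120)
r.incr("deployment-1:tpm:14-05", 80)
assert int(r.get("deployment-1:tpm:14-05")) == 200

# One MGET fetches usage for every deployment in a single round trip;
# keys that don't exist come back as None.
print(r.mget(["deployment-1:tpm:14-05", "deployment-2:tpm:14-05"]))
# [b'200', None]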
@@ -81,9 +81,29 @@ class InMemoryCache(BaseCache):
             return cached_response
         return None

+    def batch_get_cache(self, keys: list, **kwargs):
+        return_val = []
+        for k in keys:
+            val = self.get_cache(key=k, **kwargs)
+            return_val.append(val)
+        return return_val
+
     async def async_get_cache(self, key, **kwargs):
         return self.get_cache(key=key, **kwargs)

+    async def async_batch_get_cache(self, keys: list, **kwargs):
+        return_val = []
+        for k in keys:
+            val = self.get_cache(key=k, **kwargs)
+            return_val.append(val)
+        return return_val
+
+    async def async_increment(self, key, value: int, **kwargs):
+        # get the value
+        init_value = await self.async_get_cache(key=key) or 0
+        value = init_value + value
+        await self.async_set_cache(key, value, **kwargs)
+
     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()
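Worth noting: the in-memory async_increment above is a plain read-modify-write (get, add, set). That is fine within one process; cross-instance correctness comes from the Redis-side incr added further down. A tiny self-contained sketch of the same shape (TinyCache is a hypothetical stand-in for InMemoryCache):

import asyncio

class TinyCache:
    def __init__(self):
        self.store = {}

    async def async_get_cache(self, key):
        return self.store.get(key)

    async def async_set_cache(self, key, value):
        self.store[key] = value

    async def async_increment(self, key, value: int):
        # same get -> add -> set sequence as the diff above
        init_value = await self.async_get_cache(key) or 0
        await self.async_set_cache(key, init_value + value)

async def main():
    c = TinyCache()
    await c.async_increment("gpt-4:tpm:14-05", 250)
    await c.async_increment("gpt-4:tpm:14-05", 100)
    print(c.store)  # {'gpt-4:tpm:14-05': 350}

asyncio.run(main())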
@@ -246,6 +266,19 @@ class RedisCache(BaseCache):
         if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
             await self.flush_cache_buffer()

+    async def async_increment(self, key, value: int, **kwargs):
+        _redis_client = self.init_async_client()
+        try:
+            async with _redis_client as redis_client:
+                await redis_client.incr(name=key, amount=value)
+        except Exception as e:
+            verbose_logger.error(
+                "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
+                str(e),
+                value,
+            )
+            traceback.print_exc()
+
     async def flush_cache_buffer(self):
         print_verbose(
             f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
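Because redis_client.incr runs server-side, concurrent increments from separate proxy instances all land. A sketch with redis-py's asyncio client (host/port are placeholders):

import asyncio
import redis.asyncio as async_redis

async def main():
    r = async_redis.Redis(host="localhost", port=6379)  # placeholder
    key = "deployment-1:rpm:14-05"
    # ten concurrent workers, each recording one request
    await asyncio.gather(*(r.incr(key, 1) for _ in range(10)))
    print(await r.get(key))  # b'10' - no update is lost
    await r.aclose()

asyncio.run(main())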
@@ -283,6 +316,32 @@ class RedisCache(BaseCache):
             traceback.print_exc()
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

+    def batch_get_cache(self, key_list) -> dict:
+        """
+        Use Redis for bulk read operations
+        """
+        key_value_dict = {}
+        try:
+            _keys = []
+            for cache_key in key_list:
+                cache_key = self.check_and_fix_namespace(key=cache_key)
+                _keys.append(cache_key)
+            results = self.redis_client.mget(keys=_keys)
+
+            # Associate the results back with their keys.
+            # 'results' is a list of values corresponding to the order of keys in 'key_list'.
+            key_value_dict = dict(zip(key_list, results))
+
+            decoded_results = {
+                k.decode("utf-8"): self._get_cache_logic(v)
+                for k, v in key_value_dict.items()
+            }
+
+            return decoded_results
+        except Exception as e:
+            print_verbose(f"Error occurred in pipeline read - {str(e)}")
+            return key_value_dict
+
     async def async_get_cache(self, key, **kwargs):
         _redis_client = self.init_async_client()
         key = self.check_and_fix_namespace(key=key)
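MGET returns values in the same order as the requested keys, with None for misses, so zipping keys against results is safe. The associate-and-decode step in isolation (the int conversion is a simplified stand-in for _get_cache_logic):

key_list = ["deployment-1:tpm:14-05", "deployment-2:tpm:14-05"]
results = [b"1500", None]  # what mget would return: ordered, None for misses

key_value_dict = dict(zip(key_list, results))
decoded_results = {
    k: int(v) if v is not None else None
    for k, v in key_value_dict.items()
}
print(decoded_results)
# {'deployment-1:tpm:14-05': 1500, 'deployment-2:tpm:14-05': None}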
@@ -301,7 +360,7 @@ class RedisCache(BaseCache):
                 f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
             )

-    async def async_get_cache_pipeline(self, key_list) -> dict:
+    async def async_batch_get_cache(self, key_list) -> dict:
         """
         Use Redis for bulk read operations
         """
@@ -309,14 +368,11 @@ class RedisCache(BaseCache):
         key_value_dict = {}
         try:
             async with _redis_client as redis_client:
-                async with redis_client.pipeline(transaction=True) as pipe:
-                    # Queue the get operations in the pipeline for all keys.
-                    for cache_key in key_list:
-                        cache_key = self.check_and_fix_namespace(key=cache_key)
-                        pipe.get(cache_key)  # Queue GET command in pipeline
-
-                    # Execute the pipeline and await the results.
-                    results = await pipe.execute()
+                _keys = []
+                for cache_key in key_list:
+                    cache_key = self.check_and_fix_namespace(key=cache_key)
+                    _keys.append(cache_key)
+                results = await redis_client.mget(keys=_keys)

             # Associate the results back with their keys.
             # 'results' is a list of values corresponding to the order of keys in 'key_list'.
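Design note: the transactional pipeline and MGET both cost one round trip, but MGET is a single native multi-key command, so it drops the per-key GET queueing and the MULTI/EXEC overhead. Both shapes side by side, as a sketch (redis-py asyncio client, placeholder keys):

import asyncio
import redis.asyncio as async_redis

async def main():
    r = async_redis.Redis()  # placeholder: localhost:6379
    keys = ["k1", "k2", "k3"]

    # before: queue one GET per key, then execute the pipeline
    async with r.pipeline(transaction=True) as pipe:
        for k in keys:
            pipe.get(k)
        results_pipeline = await pipe.execute()

    # after: one MGET performs the same read as a single command
    results_mget = await r.mget(keys)

    assert results_pipeline == results_mget
    await r.aclose()

asyncio.run(main())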
@@ -897,6 +953,39 @@ class DualCache(BaseCache):
         except Exception as e:
             traceback.print_exc()

+    def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
+        try:
+            result = [None for _ in range(len(keys))]
+            if self.in_memory_cache is not None:
+                in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
+
+                print_verbose(f"in_memory_result: {in_memory_result}")
+                if in_memory_result is not None:
+                    result = in_memory_result
+
+            if None in result and self.redis_cache is not None and local_only == False:
+                """
+                - for the none values in the result
+                - check the redis cache
+                """
+                sublist_keys = [
+                    key for key, value in zip(keys, result) if value is None
+                ]
+                # If not found in in-memory cache, try fetching from Redis
+                redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
+                if redis_result is not None:
+                    # Update in-memory cache with the value from Redis
+                    for key in redis_result:
+                        self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
+
+                for key, value in redis_result.items():
+                    result[sublist_keys.index(key)] = value
+
+            print_verbose(f"async batch get cache: cache result: {result}")
+            return result
+        except Exception as e:
+            traceback.print_exc()
+
     async def async_get_cache(self, key, local_only: bool = False, **kwargs):
         # Try to fetch from in-memory cache first
         try:
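The DualCache batch read is a read-through merge: take whatever the in-memory layer has, collect the None slots, fetch only those keys from Redis, backfill memory, and splice the fetched values back into their original positions. The merge logic in miniature (the two dicts are hypothetical stand-ins for the in-memory and Redis layers):

keys = ["a", "b", "c"]
in_memory = {"a": 1}    # hypothetical L1 contents
redis_layer = {"b": 2}  # hypothetical L2 contents

result = [in_memory.get(k) for k in keys]                      # [1, None, None]
sublist_keys = [k for k, v in zip(keys, result) if v is None]  # ['b', 'c']

redis_result = {k: redis_layer[k] for k in sublist_keys if k in redis_layer}
for k, v in redis_result.items():
    in_memory[k] = v           # backfill the in-memory layer
    result[keys.index(k)] = v  # splice into the original position

print(result)  # [1, 2, None] - 'c' is missing from both layers, stays None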
@@ -930,6 +1019,50 @@ class DualCache(BaseCache):
         except Exception as e:
             traceback.print_exc()

+    async def async_batch_get_cache(
+        self, keys: list, local_only: bool = False, **kwargs
+    ):
+        try:
+            result = [None for _ in range(len(keys))]
+            if self.in_memory_cache is not None:
+                in_memory_result = await self.in_memory_cache.async_batch_get_cache(
+                    keys, **kwargs
+                )
+
+                print_verbose(f"in_memory_result: {in_memory_result}")
+                if in_memory_result is not None:
+                    result = in_memory_result
+
+            if None in result and self.redis_cache is not None and local_only == False:
+                """
+                - for the none values in the result
+                - check the redis cache
+                """
+                sublist_keys = [
+                    key for key, value in zip(keys, result) if value is None
+                ]
+                # If not found in in-memory cache, try fetching from Redis
+                redis_result = await self.redis_cache.async_batch_get_cache(
+                    sublist_keys, **kwargs
+                )
+
+                if redis_result is not None:
+                    # Update in-memory cache with the value from Redis
+                    for key in redis_result:
+                        await self.in_memory_cache.async_set_cache(
+                            key, redis_result[key], **kwargs
+                        )
+
+                sublist_dict = dict(zip(sublist_keys, redis_result))
+
+                for key, value in sublist_dict.items():
+                    result[sublist_keys.index(key)] = value[key]
+
+            print_verbose(f"async batch get cache: cache result: {result}")
+            return result
+        except Exception as e:
+            traceback.print_exc()
+
     async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
         try:
             if self.in_memory_cache is not None:
@@ -941,6 +1074,24 @@ class DualCache(BaseCache):
             print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
             traceback.print_exc()

+    async def async_increment_cache(
+        self, key, value: int, local_only: bool = False, **kwargs
+    ):
+        """
+        Key - the key in cache
+
+        Value - int - the value you want to increment by
+        """
+        try:
+            if self.in_memory_cache is not None:
+                await self.in_memory_cache.async_increment(key, value, **kwargs)
+
+            if self.redis_cache is not None and local_only == False:
+                await self.redis_cache.async_increment(key, value, **kwargs)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
+            traceback.print_exc()
+
     def flush_cache(self):
         if self.in_memory_cache is not None:
             self.in_memory_cache.flush_cache()
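This becomes the only write path on the v2 request hot loop: one in-memory add plus one Redis INCR per counter. A usage sketch under the key scheme the new handler uses (record_usage is a hypothetical helper, assuming an already-configured DualCache):

import asyncio
from datetime import datetime

from litellm.caching import DualCache

async def record_usage(cache: DualCache, deployment_id: str, total_tokens: int):
    # mirrors what async_log_success_event does per request
    current_minute = datetime.now().strftime("%H-%M")
    await cache.async_increment_cache(
        key=f"{deployment_id}:tpm:{current_minute}", value=total_tokens
    )
    await cache.async_increment_cache(
        key=f"{deployment_id}:rpm:{current_minute}", value=1
    )

# asyncio.run(record_usage(my_dual_cache, "deployment-1", 512))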
@@ -28,7 +28,7 @@ litellm_settings:
   max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET

 router_settings:
-  routing_strategy: usage-based-routing
+  routing_strategy: usage-based-routing-v2
   redis_host: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
   redis_password: madeBerri@992
   redis_port: 16337
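For SDK users, the equivalent switch is the routing_strategy argument on litellm's Router; the deployment below is a made-up example, and the redis_* keyword names are assumed to mirror the router_settings keys above:

from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # model group
        "litellm_params": {
            "model": "azure/chatgpt-v-2",  # placeholder deployment
            "tpm": 100000,
            "rpm": 1000,
        },
    },
]

router = Router(
    model_list=model_list,
    routing_strategy="usage-based-routing-v2",  # the strategy this commit adds
    redis_host="localhost",  # placeholder Redis connection
    redis_port=6379,
)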
@@ -79,7 +79,7 @@ class _PROXY_BatchRedisRequests(CustomLogger):
             self.print_verbose(f"redis keys: {keys}")
             if len(keys) > 0:
                 key_value_dict = (
-                    await litellm.cache.cache.async_get_cache_pipeline(
+                    await litellm.cache.cache.async_batch_get_cache(
                         key_list=keys
                     )
                 )
@@ -21,6 +21,7 @@ from collections import defaultdict
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
 from litellm.llms.custom_httpx.azure_dall_e_2 import (
     CustomHTTPTransport,
     AsyncCustomHTTPTransport,
@@ -273,6 +274,12 @@ class Router:
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
+        elif routing_strategy == "usage-based-routing-v2":
+            self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
+                router_cache=self.cache, model_list=self.model_list
+            )
+            if isinstance(litellm.callbacks, list):
+                litellm.callbacks.append(self.lowesttpm_logger_v2)  # type: ignore
         elif routing_strategy == "latency-based-routing":
             self.lowestlatency_logger = LowestLatencyLoggingHandler(
                 router_cache=self.cache,
@@ -2506,7 +2513,16 @@ class Router:
                 messages=messages,
                 input=input,
             )
+        elif (
+            self.routing_strategy == "usage-based-routing-v2"
+            and self.lowesttpm_logger_v2 is not None
+        ):
+            deployment = self.lowesttpm_logger_v2.get_available_deployments(
+                model_group=model,
+                healthy_deployments=healthy_deployments,
+                messages=messages,
+                input=input,
+            )
         if deployment is None:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
258  litellm/router_strategy/lowest_tpm_rpm_v2.py  (new file)

@@ -0,0 +1,258 @@
+#### What this does ####
+# identifies lowest tpm deployment
+
+import dotenv, os, requests, random
+from typing import Optional, Union, List, Dict
+from datetime import datetime
+
+dotenv.load_dotenv()  # Loading env variables using dotenv
+import traceback, asyncio
+from litellm import token_counter
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm._logging import verbose_router_logger
+from litellm.utils import print_verbose
+
+
+class LowestTPMLoggingHandler_v2(CustomLogger):
+    """
+    Updated version of TPM/RPM Logging.
+
+    Meant to work across instances.
+
+    Caches individual models, not model_groups
+
+    Uses batch get (redis.mget)
+
+    Increments tpm/rpm limit using redis.incr
+    """
+
+    test_flag: bool = False
+    logged_success: int = 0
+    logged_failure: int = 0
+    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
+
+    def __init__(self, router_cache: DualCache, model_list: list):
+        self.router_cache = router_cache
+        self.model_list = model_list
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            """
+            Update TPM/RPM usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                total_tokens = response_obj["usage"]["total_tokens"]
+
+                # ------------
+                # Setup values
+                # ------------
+                current_minute = datetime.now().strftime("%H-%M")
+                tpm_key = f"{model_group}:tpm:{current_minute}"
+                rpm_key = f"{model_group}:rpm:{current_minute}"
+
+                # ------------
+                # Update usage
+                # ------------
+
+                ## TPM
+                request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
+                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
+
+                self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
+
+                ## RPM
+                request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
+                request_count_dict[id] = request_count_dict.get(id, 0) + 1
+
+                self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_success += 1
+        except Exception as e:
+            traceback.print_exc()
+            pass
+
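Note the sync path above still does a read-modify-write on a usage dict stored under the model_group key; when two instances interleave their get and set, one update is silently lost, which is exactly the race the async path below avoids with atomic increments. The race in miniature:

store = {"gpt-4:tpm:14-05": {"deployment-1": 100}}  # shared cache, simplified

# instance A and instance B both read the same snapshot...
snapshot_a = dict(store["gpt-4:tpm:14-05"])
snapshot_b = dict(store["gpt-4:tpm:14-05"])

# ...each adds its own tokens and writes the whole dict back
snapshot_a["deployment-1"] = snapshot_a.get("deployment-1", 0) + 50
store["gpt-4:tpm:14-05"] = snapshot_a

snapshot_b["deployment-1"] = snapshot_b.get("deployment-1", 0) + 70
store["gpt-4:tpm:14-05"] = snapshot_b

print(store)  # {'gpt-4:tpm:14-05': {'deployment-1': 170}} - the +50 was lost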
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            """
+            Update TPM/RPM usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                total_tokens = response_obj["usage"]["total_tokens"]
+
+                # ------------
+                # Setup values
+                # ------------
+                current_minute = datetime.now().strftime("%H-%M")
+
+                tpm_key = f"{id}:tpm:{current_minute}"
+                rpm_key = f"{id}:rpm:{current_minute}"
+
+                # ------------
+                # Update usage
+                # ------------
+                # update cache
+
+                ## TPM
+                await self.router_cache.async_increment_cache(
+                    key=tpm_key, value=total_tokens
+                )
+                ## RPM
+                await self.router_cache.async_increment_cache(key=rpm_key, value=1)
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_success += 1
+        except Exception as e:
+            traceback.print_exc()
+            pass
+
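The structural change from v1 is visible in the keys: counters live per deployment id ({id}:tpm:{HH-MM}) as plain integers rather than as one dict per model group, so every instance can bump its own slice with INCR and a reader can MGET all ids at once. Sketch of one minute's key layout (ids are hypothetical):

from datetime import datetime

deployment_ids = ["azure-eastus-1", "azure-westus-2"]
current_minute = datetime.now().strftime("%H-%M")  # e.g. "14-05"

tpm_keys = [f"{d}:tpm:{current_minute}" for d in deployment_ids]
rpm_keys = [f"{d}:rpm:{current_minute}" for d in deployment_ids]
print(tpm_keys)  # ['azure-eastus-1:tpm:14-05', 'azure-westus-2:tpm:14-05']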
+    async def async_get_available_deployments(
+        self,
+        model_group: str,
+        healthy_deployments: list,
+        messages: Optional[List[Dict[str, str]]] = None,
+        input: Optional[Union[str, List]] = None,
+    ):
+        """
+        Async implementation of get deployments.
+
+        Reduces time to retrieve the tpm/rpm values from cache
+        """
+        pass
+
+    def get_available_deployments(
+        self,
+        model_group: str,
+        healthy_deployments: list,
+        messages: Optional[List[Dict[str, str]]] = None,
+        input: Optional[Union[str, List]] = None,
+    ):
+        """
+        Returns a deployment with the lowest TPM/RPM usage.
+        """
+        # get list of potential deployments
+        verbose_router_logger.debug(
+            f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
+        )
+
+        current_minute = datetime.now().strftime("%H-%M")
+        tpm_keys = []
+        rpm_keys = []
+        for m in healthy_deployments:
+            if isinstance(m, dict):
+                id = m.get("model_info", {}).get(
+                    "id"
+                )  # a deployment should always have an 'id'. this is set in router.py
+                tpm_key = "{}:tpm:{}".format(id, current_minute)
+                rpm_key = "{}:rpm:{}".format(id, current_minute)
+
+                tpm_keys.append(tpm_key)
+                rpm_keys.append(rpm_key)
+
+        tpm_values = self.router_cache.batch_get_cache(
+            keys=tpm_keys
+        )  # [1, 2, None, ..]
+        rpm_values = self.router_cache.batch_get_cache(
+            keys=rpm_keys
+        )  # [1, 2, None, ..]
+
+        tpm_dict = {}  # {model_id: 1, ..}
+        for idx, key in enumerate(tpm_keys):
+            tpm_dict[tpm_keys[idx]] = tpm_values[idx]
+
+        rpm_dict = {}  # {model_id: 1, ..}
+        for idx, key in enumerate(rpm_keys):
+            rpm_dict[rpm_keys[idx]] = rpm_values[idx]
+
+        try:
+            input_tokens = token_counter(messages=messages, text=input)
+        except:
+            input_tokens = 0
+        verbose_router_logger.debug(f"input_tokens={input_tokens}")
+        # -----------------------
+        # Find lowest used model
+        # ----------------------
+        lowest_tpm = float("inf")
+
+        if tpm_dict is None:  # base case - none of the deployments have been used
+            # initialize a tpm dict with {model_id: 0}
+            tpm_dict = {}
+            for deployment in healthy_deployments:
+                tpm_dict[deployment["model_info"]["id"]] = 0
+        else:
+            for d in healthy_deployments:
+                ## if healthy deployment not yet used
+                if d["model_info"]["id"] not in tpm_dict:
+                    tpm_dict[d["model_info"]["id"]] = 0
+
+        all_deployments = tpm_dict
+
+        deployment = None
+        for item, item_tpm in all_deployments.items():
+            ## get the item from model list
+            _deployment = None
+            for m in healthy_deployments:
+                if item == m["model_info"]["id"]:
+                    _deployment = m
+
+            if _deployment is None:
+                continue  # skip to next one
+
+            _deployment_tpm = None
+            if _deployment_tpm is None:
+                _deployment_tpm = _deployment.get("tpm")
+            if _deployment_tpm is None:
+                _deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
+            if _deployment_tpm is None:
+                _deployment_tpm = _deployment.get("model_info", {}).get("tpm")
+            if _deployment_tpm is None:
+                _deployment_tpm = float("inf")
+
+            _deployment_rpm = None
+            if _deployment_rpm is None:
+                _deployment_rpm = _deployment.get("rpm")
+            if _deployment_rpm is None:
+                _deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
+            if _deployment_rpm is None:
+                _deployment_rpm = _deployment.get("model_info", {}).get("rpm")
+            if _deployment_rpm is None:
+                _deployment_rpm = float("inf")
+
+            if item_tpm + input_tokens > _deployment_tpm:
+                continue
+            elif (rpm_dict is not None and item in rpm_dict) and (
+                rpm_dict[item] + 1 > _deployment_rpm
+            ):
+                continue
+            elif item_tpm < lowest_tpm:
+                lowest_tpm = item_tpm
+                deployment = _deployment
+        print_verbose("returning picked lowest tpm/rpm deployment.")
+        return deployment
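The selection rule, end to end: skip a deployment if the projected tokens would exceed its TPM limit or one more request would exceed its RPM limit; among the survivors, the lowest current TPM wins. A worked miniature with invented numbers:

input_tokens = 400
tpm_usage = {"d1": 900, "d2": 200, "d3": 0}  # current-minute token counts
rpm_usage = {"d1": 4, "d2": 9, "d3": 0}      # current-minute request counts
limits = {"d1": (1000, 10), "d2": (1000, 9), "d3": (1000, 10)}  # (tpm, rpm)

deployment, lowest_tpm = None, float("inf")
for d, used_tpm in tpm_usage.items():
    tpm_limit, rpm_limit = limits[d]
    if used_tpm + input_tokens > tpm_limit:
        continue  # d1: 900 + 400 > 1000, over the token budget
    if rpm_usage[d] + 1 > rpm_limit:
        continue  # d2: 9 + 1 > 9, over the request budget
    if used_tpm < lowest_tpm:
        lowest_tpm, deployment = used_tpm, d

print(deployment)  # 'd3' - within budget and lowest current usage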