Merge pull request #2942 from BerriAI/litellm_fix_router_loading

Router Async Improvements
Krish Dholakia 2024-04-10 20:16:53 -07:00 committed by GitHub
commit 83e7ed94ce
10 changed files with 746 additions and 60 deletions

View file

@@ -81,9 +81,29 @@ class InMemoryCache(BaseCache):
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs):
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
def flush_cache(self):
self.cache_dict.clear()
self.ttl_dict.clear()
@@ -246,6 +266,19 @@ class RedisCache(BaseCache):
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer()
async def async_increment(self, key, value: int, **kwargs):
_redis_client = self.init_async_client()
try:
async with _redis_client as redis_client:
await redis_client.incr(name=key, amount=value)
except Exception as e:
verbose_logger.error(
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
async def flush_cache_buffer(self):
print_verbose(
f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
@@ -283,6 +316,32 @@ class RedisCache(BaseCache):
traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
def batch_get_cache(self, key_list) -> dict:
"""
Use Redis for bulk read operations
"""
key_value_dict = {}
try:
_keys = []
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key)
results = self.redis_client.mget(keys=_keys)
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
key_value_dict = dict(zip(key_list, results))
decoded_results = {
k.decode("utf-8"): self._get_cache_logic(v)
for k, v in key_value_dict.items()
}
return decoded_results
except Exception as e:
print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict
async def async_get_cache(self, key, **kwargs):
_redis_client = self.init_async_client()
key = self.check_and_fix_namespace(key=key)
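For readers unfamiliar with the primitive behind the new `batch_get_cache` above: redis-py's `mget` returns values in the same order as the requested keys, with `None` for keys that were never written, which is what lets callers map results back onto their key list. A minimal sketch, assuming a local Redis on the default port (keys and values are made up):

import redis

r = redis.Redis(host="localhost", port=6379)  # assumed local instance
r.set("deployment-a:tpm:20-15", 1200)

# mget preserves request order and fills misses with None,
# so zip(keys, results) reconstructs the key -> value mapping.
keys = ["deployment-a:tpm:20-15", "deployment-b:tpm:20-15"]
print(dict(zip(keys, r.mget(keys))))
# {'deployment-a:tpm:20-15': b'1200', 'deployment-b:tpm:20-15': None}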
@@ -301,7 +360,7 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
)
-async def async_get_cache_pipeline(self, key_list) -> dict:
+async def async_batch_get_cache(self, key_list) -> dict:
"""
Use Redis for bulk read operations
"""
@@ -309,14 +368,11 @@ class RedisCache(BaseCache):
key_value_dict = {}
try:
async with _redis_client as redis_client:
-async with redis_client.pipeline(transaction=True) as pipe:
-# Queue the get operations in the pipeline for all keys.
+_keys = []
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
-pipe.get(cache_key) # Queue GET command in pipeline
-# Execute the pipeline and await the results.
-results = await pipe.execute()
+_keys.append(cache_key)
+results = await redis_client.mget(keys=_keys)
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
@@ -897,6 +953,39 @@ class DualCache(BaseCache):
except Exception as e:
traceback.print_exc()
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
- check the redis cache
"""
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
for key, value in redis_result.items():
result[sublist_keys.index(key)] = value
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first
try:
@@ -930,6 +1019,50 @@ class DualCache(BaseCache):
except Exception as e:
traceback.print_exc()
async def async_batch_get_cache(
self, keys: list, local_only: bool = False, **kwargs
):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
keys, **kwargs
)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
- check the redis cache
"""
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, **kwargs
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
await self.in_memory_cache.async_set_cache(
key, redis_result[key], **kwargs
)
sublist_dict = dict(zip(sublist_keys, redis_result))
for key, value in sublist_dict.items():
result[sublist_keys.index(key)] = value[key]
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
try:
if self.in_memory_cache is not None:
@@ -941,6 +1074,24 @@ class DualCache(BaseCache):
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
async def async_increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
):
"""
Key - the key in cache
Value - int - the value you want to increment by
"""
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_increment(key, value, **kwargs)
if self.redis_cache is not None and local_only == False:
await self.redis_cache.async_increment(key, value, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
def flush_cache(self):
if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache()
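A rough usage sketch of the two-tier batch read and the new increment helper added to `DualCache` above; constructing the cache this way (in-memory only, no Redis) is an assumption for illustration, not something the PR does:

import asyncio
from litellm.caching import DualCache, InMemoryCache

async def main():
    # In-memory layer only; with a RedisCache attached, None results would be
    # backfilled from Redis via async_batch_get_cache.
    cache = DualCache(in_memory_cache=InMemoryCache(), redis_cache=None)

    await cache.async_set_cache("deployment-a:tpm:20-15", 500)
    await cache.async_increment_cache(key="deployment-a:tpm:20-15", value=250)

    result = await cache.async_batch_get_cache(
        keys=["deployment-a:tpm:20-15", "deployment-b:tpm:20-15"]
    )
    print(result)  # expected: [750, None] - misses stay None

asyncio.run(main())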

View file

@@ -27,12 +27,11 @@ litellm_settings:
upperbound_key_generate_params:
max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
-# litellm_settings:
-# drop_params: True
-# max_budget: 800021
-# budget_duration: 30d
-# # cache: true
+router_settings:
+routing_strategy: usage-based-routing-v2
+redis_host: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
+redis_password: madeBerri@992
+redis_port: 16337
general_settings:
master_key: sk-1234
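The same `usage-based-routing-v2` strategy can be enabled when building a `Router` in Python rather than through the proxy config. A minimal sketch; the model entry and Redis connection details below are placeholders, not values from this PR:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",            # placeholder deployment
                "api_key": "os.environ/AZURE_API_KEY",   # resolved via get_secret
                "api_base": "os.environ/AZURE_API_BASE",
            },
        }
    ],
    routing_strategy="usage-based-routing-v2",
    redis_host="localhost",   # shared Redis so TPM/RPM counters are
    redis_port=6379,          # visible across proxy instances
)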

View file

@@ -79,7 +79,7 @@ class _PROXY_BatchRedisRequests(CustomLogger):
self.print_verbose(f"redis keys: {keys}")
if len(keys) > 0:
key_value_dict = (
-await litellm.cache.cache.async_get_cache_pipeline(
+await litellm.cache.cache.async_batch_get_cache(
key_list=keys
)
)

View file

@@ -425,9 +425,10 @@ def run_server(
)
proxy_config = ProxyConfig()
-_, _, general_settings = asyncio.run(
-proxy_config.load_config(router=None, config_file_path=config)
-)
+_config = asyncio.run(proxy_config.get_config(config_file_path=config))
+general_settings = _config.get("general_settings", {})
+if general_settings is None:
+general_settings = {}
database_url = general_settings.get("database_url", None)
db_connection_pool_limit = general_settings.get(
"database_connection_pool_limit", 100

View file

@@ -2335,6 +2335,7 @@ class ProxyConfig:
"background_health_checks", False
)
health_check_interval = general_settings.get("health_check_interval", 300)
router_params: dict = {
"cache_responses": litellm.cache
!= None, # cache if user passed in cache values

View file

@@ -11,9 +11,9 @@ import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
import random, threading, time, traceback, uuid
-import litellm, openai
+import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache
import datetime as datetime_og
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
@@ -21,11 +21,12 @@ from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.llms.custom_httpx.azure_dall_e_2 import (
CustomHTTPTransport,
AsyncCustomHTTPTransport,
)
-from litellm.utils import ModelResponse, CustomStreamWrapper
+from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
import copy
from litellm._logging import verbose_router_logger
import logging
@@ -273,6 +274,12 @@ class Router:
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
elif routing_strategy == "usage-based-routing-v2":
self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
router_cache=self.cache, model_list=self.model_list
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
elif routing_strategy == "latency-based-routing": elif routing_strategy == "latency-based-routing":
self.lowestlatency_logger = LowestLatencyLoggingHandler( self.lowestlatency_logger = LowestLatencyLoggingHandler(
router_cache=self.cache, router_cache=self.cache,
@@ -407,7 +414,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
)
-deployment = self.get_available_deployment(
+deployment = await self.async_get_available_deployment(
model=model,
messages=messages,
specific_deployment=kwargs.pop("specific_deployment", None),
@@ -581,7 +588,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
)
-deployment = self.get_available_deployment(
+deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
@@ -681,7 +688,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
)
-deployment = self.get_available_deployment(
+deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
@@ -761,7 +768,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
)
-deployment = self.get_available_deployment(
+deployment = await self.async_get_available_deployment(
model=model,
input=input,
specific_deployment=kwargs.pop("specific_deployment", None),
@@ -904,7 +911,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
)
-deployment = self.get_available_deployment(
+deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": prompt}],
specific_deployment=kwargs.pop("specific_deployment", None),
@@ -1070,7 +1077,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
)
-deployment = self.get_available_deployment(
+deployment = await self.async_get_available_deployment(
model=model,
input=input,
specific_deployment=kwargs.pop("specific_deployment", None),
@@ -1598,7 +1605,8 @@ class Router:
if deployment is None:
return
-current_minute = datetime.now().strftime("%H-%M")
+dt = get_utc_datetime()
+current_minute = dt.strftime("%H-%M")
# get current fails for deployment
# update the number of failed calls
# if it's > allowed fails
@@ -1636,11 +1644,29 @@ class Router:
key=deployment, value=updated_fails, ttl=cooldown_time
)
async def _async_get_cooldown_deployments(self):
"""
Async implementation of '_get_cooldown_deployments'
"""
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models"
# ----------------------
# Return cooldown models
# ----------------------
cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def _get_cooldown_deployments(self):
"""
Get the list of models being cooled down for this minute
"""
-current_minute = datetime.now().strftime("%H-%M")
+dt = get_utc_datetime()
+current_minute = dt.strftime("%H-%M")
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models"
@@ -2065,6 +2091,34 @@ class Router:
local_only=True,
) # cache for 1 hr
def _generate_model_id(self, model_group: str, litellm_params: dict):
"""
Helper function to consistently generate the same id for a deployment
- create a string from all the litellm params
- hash
- use hash as id
"""
concat_str = model_group
for k, v in litellm_params.items():
if isinstance(k, str):
concat_str += k
elif isinstance(k, dict):
concat_str += json.dumps(k)
else:
concat_str += str(k)
if isinstance(v, str):
concat_str += v
elif isinstance(v, dict):
concat_str += json.dumps(v)
else:
concat_str += str(v)
hash_object = hashlib.sha256(concat_str.encode())
return hash_object.hexdigest()
def set_model_list(self, model_list: list):
original_model_list = copy.deepcopy(model_list)
self.model_list = []
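One consequence of `_generate_model_id` worth spelling out: because the id is a SHA-256 digest of the model group plus the deployment's litellm params, the same config yields the same deployment id on every restart and on every proxy instance, so the per-deployment TPM/RPM keys in the shared cache line up. A quick sketch (the params are placeholders):

from litellm import Router

litellm_params = {
    "model": "openai/my-fake-model",   # placeholder values
    "api_key": "my-fake-key",
    "api_base": "https://example.com/",
}

# Two independent Router instances derive the identical id for identical inputs.
id_a = Router()._generate_model_id("gpt-3.5-turbo", litellm_params)
id_b = Router()._generate_model_id("gpt-3.5-turbo", litellm_params)
assert id_a == id_b
print(id_a)  # 64-character hex digest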
@@ -2080,7 +2134,13 @@ class Router:
if isinstance(v, str) and v.startswith("os.environ/"):
_litellm_params[k] = litellm.get_secret(v)
-_model_info = model.pop("model_info", {})
+_model_info: dict = model.pop("model_info", {})
# check if model info has id
if "id" not in _model_info:
_id = self._generate_model_id(_model_name, _litellm_params)
_model_info["id"] = _id
deployment = Deployment(
**model,
model_name=_model_name,
@@ -2279,7 +2339,8 @@ class Router:
_rate_limit_error = False
## get model group RPM ##
-current_minute = datetime.now().strftime("%H-%M")
+dt = get_utc_datetime()
+current_minute = dt.strftime("%H-%M")
rpm_key = f"{model}:rpm:{current_minute}"
model_group_cache = (
self.cache.get_cache(key=rpm_key, local_only=True) or {}
@@ -2364,7 +2425,7 @@ class Router:
return _returned_deployments
-def get_available_deployment(
+def _common_checks_available_deployment(
self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
@@ -2372,11 +2433,11 @@ class Router:
specific_deployment: Optional[bool] = False,
):
"""
-Returns the deployment based on routing strategy
-"""
-# users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
-# When this was no explicit we had several issues with fallbacks timing out
+Common checks for 'get_available_deployment' across sync + async call.
+If 'healthy_deployments' returned is None, this means the user chose a specific deployment
+"""
# check if aliases set on litellm model alias map
if specific_deployment == True:
@@ -2384,12 +2445,11 @@ class Router:
if deployment_model == model:
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
# return the first deployment where the `model` matches the specificed deployment name
-return deployment
+return deployment, None
raise ValueError(
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
)
# check if aliases set on litellm model alias map
if model in self.model_group_alias:
verbose_router_logger.debug(
f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
@@ -2401,7 +2461,7 @@ class Router:
self.default_deployment
) # self.default_deployment
updated_deployment["litellm_params"]["model"] = model
-return updated_deployment
+return updated_deployment, None
## get healthy deployments
### get all deployments
@@ -2416,6 +2476,118 @@ class Router:
f"initial list of deployments: {healthy_deployments}"
)
verbose_router_logger.debug(
f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
)
if len(healthy_deployments) == 0:
raise ValueError(f"No healthy deployment available, passed model={model}")
if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
return model, healthy_deployments
async def async_get_available_deployment(
self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False,
):
"""
Async implementation of 'get_available_deployments'.
Allows all cache calls to be made async => 10x perf impact (8rps -> 100 rps).
"""
if (
self.routing_strategy != "usage-based-routing-v2"
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
return self.get_available_deployment(
model=model,
messages=messages,
input=input,
specific_deployment=specific_deployment,
)
model, healthy_deployments = self._common_checks_available_deployment(
model=model,
messages=messages,
input=input,
specific_deployment=specific_deployment,
)
if healthy_deployments is None:
return model
# filter out the deployments currently cooling down
deployments_to_remove = []
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
cooldown_deployments = await self._async_get_cooldown_deployments()
verbose_router_logger.debug(
f"async cooldown deployments: {cooldown_deployments}"
)
# Find deployments in model_list whose model_id is cooling down
for deployment in healthy_deployments:
deployment_id = deployment["model_info"]["id"]
if deployment_id in cooldown_deployments:
deployments_to_remove.append(deployment)
# remove unhealthy deployments from healthy deployments
for deployment in deployments_to_remove:
healthy_deployments.remove(deployment)
# filter pre-call checks
if self.enable_pre_call_checks and messages is not None:
healthy_deployments = self._pre_call_checks(
model=model, healthy_deployments=healthy_deployments, messages=messages
)
if (
self.routing_strategy == "usage-based-routing-v2"
and self.lowesttpm_logger_v2 is not None
):
deployment = await self.lowesttpm_logger_v2.async_get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments,
messages=messages,
input=input,
)
if deployment is None:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
)
raise ValueError(
f"No deployments available for selected model, passed model={model}"
)
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
)
return deployment
def get_available_deployment(
self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False,
):
"""
Returns the deployment based on routing strategy
"""
# users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
# When this was no explicit we had several issues with fallbacks timing out
model, healthy_deployments = self._common_checks_available_deployment(
model=model,
messages=messages,
input=input,
specific_deployment=specific_deployment,
)
if healthy_deployments is None:
return model
# filter out the deployments currently cooling down
deployments_to_remove = []
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
@@ -2436,16 +2608,6 @@ class Router:
model=model, healthy_deployments=healthy_deployments, messages=messages
)
-verbose_router_logger.debug(
-f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
-)
-if len(healthy_deployments) == 0:
-raise ValueError(f"No healthy deployment available, passed model={model}")
-if litellm.model_alias_map and model in litellm.model_alias_map:
-model = litellm.model_alias_map[
-model
-] # update the model to the actual value if an alias has been passed in
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None: if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
deployment = self.leastbusy_logger.get_available_deployments( deployment = self.leastbusy_logger.get_available_deployments(
model_group=model, healthy_deployments=healthy_deployments model_group=model, healthy_deployments=healthy_deployments
@@ -2507,7 +2669,16 @@ class Router:
messages=messages,
input=input,
)
elif (
self.routing_strategy == "usage-based-routing-v2"
and self.lowesttpm_logger_v2 is not None
):
deployment = self.lowesttpm_logger_v2.get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments,
messages=messages,
input=input,
)
if deployment is None:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"

View file

@@ -0,0 +1,325 @@
#### What this does ####
# identifies lowest tpm deployment
import dotenv, os, requests, random
from typing import Optional, Union, List, Dict
import datetime as datetime_og
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback, asyncio
from litellm import token_counter
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_router_logger
from litellm.utils import print_verbose, get_utc_datetime
class LowestTPMLoggingHandler_v2(CustomLogger):
"""
Updated version of TPM/RPM Logging.
Meant to work across instances.
Caches individual models, not model_groups
Uses batch get (redis.mget)
Increments tpm/rpm limit using redis.incr
"""
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
def __init__(self, router_cache: DualCache, model_list: list):
self.router_cache = router_cache
self.model_list = model_list
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM/RPM usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = response_obj["usage"]["total_tokens"]
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_key = f"{model_group}:tpm:{current_minute}"
rpm_key = f"{model_group}:rpm:{current_minute}"
# ------------
# Update usage
# ------------
## TPM
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
## RPM
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM/RPM usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = response_obj["usage"]["total_tokens"]
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"{id}:tpm:{current_minute}"
rpm_key = f"{id}:rpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
## TPM
await self.router_cache.async_increment_cache(
key=tpm_key, value=total_tokens
)
## RPM
await self.router_cache.async_increment_cache(key=rpm_key, value=1)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
pass
def _common_checks_available_deployment(
self,
model_group: str,
healthy_deployments: list,
tpm_keys: list,
tpm_values: list,
rpm_keys: list,
rpm_values: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
):
"""
Common checks for get available deployment, across sync + async implementations
"""
tpm_dict = {} # {model_id: 1, ..}
for idx, key in enumerate(tpm_keys):
tpm_dict[tpm_keys[idx]] = tpm_values[idx]
rpm_dict = {} # {model_id: 1, ..}
for idx, key in enumerate(rpm_keys):
rpm_dict[rpm_keys[idx]] = rpm_values[idx]
try:
input_tokens = token_counter(messages=messages, text=input)
except:
input_tokens = 0
verbose_router_logger.debug(f"input_tokens={input_tokens}")
# -----------------------
# Find lowest used model
# ----------------------
lowest_tpm = float("inf")
if tpm_dict is None: # base case - none of the deployments have been used
# initialize a tpm dict with {model_id: 0}
tpm_dict = {}
for deployment in healthy_deployments:
tpm_dict[deployment["model_info"]["id"]] = 0
else:
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in tpm_dict:
tpm_dict[d["model_info"]["id"]] = 0
all_deployments = tpm_dict
deployment = None
for item, item_tpm in all_deployments.items():
## get the item from model list
_deployment = None
for m in healthy_deployments:
if item == m["model_info"]["id"]:
_deployment = m
if _deployment is None:
continue # skip to next one
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm")
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get("tpm")
if _deployment_tpm is None:
_deployment_tpm = float("inf")
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm")
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = float("inf")
if item_tpm + input_tokens > _deployment_tpm:
continue
elif (rpm_dict is not None and item in rpm_dict) and (
rpm_dict[item] + 1 > _deployment_rpm
):
continue
elif item_tpm < lowest_tpm:
lowest_tpm = item_tpm
deployment = _deployment
print_verbose("returning picked lowest tpm/rpm deployment.")
return deployment
async def async_get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
):
"""
Async implementation of get deployments.
Reduces time to retrieve the tpm/rpm values from cache
"""
# get list of potential deployments
verbose_router_logger.debug(
f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
)
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_keys = []
rpm_keys = []
for m in healthy_deployments:
if isinstance(m, dict):
id = m.get("model_info", {}).get(
"id"
) # a deployment should always have an 'id'. this is set in router.py
tpm_key = "{}:tpm:{}".format(id, current_minute)
rpm_key = "{}:rpm:{}".format(id, current_minute)
tpm_keys.append(tpm_key)
rpm_keys.append(rpm_key)
tpm_values = await self.router_cache.async_batch_get_cache(
keys=tpm_keys
) # [1, 2, None, ..]
rpm_values = await self.router_cache.async_batch_get_cache(
keys=rpm_keys
) # [1, 2, None, ..]
return self._common_checks_available_deployment(
model_group=model_group,
healthy_deployments=healthy_deployments,
tpm_keys=tpm_keys,
tpm_values=tpm_values,
rpm_keys=rpm_keys,
rpm_values=rpm_values,
messages=messages,
input=input,
)
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
):
"""
Returns a deployment with the lowest TPM/RPM usage.
"""
# get list of potential deployments
verbose_router_logger.debug(
f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
)
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_keys = []
rpm_keys = []
for m in healthy_deployments:
if isinstance(m, dict):
id = m.get("model_info", {}).get(
"id"
) # a deployment should always have an 'id'. this is set in router.py
tpm_key = "{}:tpm:{}".format(id, current_minute)
rpm_key = "{}:rpm:{}".format(id, current_minute)
tpm_keys.append(tpm_key)
rpm_keys.append(rpm_key)
tpm_values = self.router_cache.batch_get_cache(
keys=tpm_keys
) # [1, 2, None, ..]
rpm_values = self.router_cache.batch_get_cache(
keys=rpm_keys
) # [1, 2, None, ..]
return self._common_checks_available_deployment(
model_group=model_group,
healthy_deployments=healthy_deployments,
tpm_keys=tpm_keys,
tpm_values=tpm_values,
rpm_keys=rpm_keys,
rpm_values=rpm_values,
messages=messages,
input=input,
)
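To make the key scheme above concrete: the v2 handler counts usage per deployment id and per UTC minute, and every router instance that shares the cache increments and reads the same counters. A small sketch of the keys it builds (the deployment id and token count are placeholders):

from litellm.utils import get_utc_datetime

deployment_id = "16700539-b3cd-42f4-b426-6a12a1bb706a"  # placeholder id
total_tokens = 314

current_minute = get_utc_datetime().strftime("%H-%M")
tpm_key = f"{deployment_id}:tpm:{current_minute}"
rpm_key = f"{deployment_id}:rpm:{current_minute}"

# On each successful call the handler effectively does:
#   await router_cache.async_increment_cache(key=tpm_key, value=total_tokens)
#   await router_cache.async_increment_cache(key=rpm_key, value=1)
# and get_available_deployments fetches all deployments' counters in one
# batch get (Redis MGET) over the tpm_keys / rpm_keys lists.
print(tpm_key, rpm_key)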

View file

@@ -932,6 +932,35 @@ def test_openai_completion_on_router():
# test_openai_completion_on_router()
def test_consistent_model_id():
"""
- For a given model group + litellm params, assert the model id is always the same
Test on `_generate_model_id`
Test on `set_model_list`
Test on `_add_deployment`
"""
model_group = "gpt-3.5-turbo"
litellm_params = {
"model": "openai/my-fake-model",
"api_key": "my-fake-key",
"api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
"stream_timeout": 0.001,
}
id1 = Router()._generate_model_id(
model_group=model_group, litellm_params=litellm_params
)
id2 = Router()._generate_model_id(
model_group=model_group, litellm_params=litellm_params
)
assert id1 == id2
def test_reading_keys_os_environ():
import openai

View file

@@ -47,6 +47,7 @@ class RouterConfig(BaseModel):
class Config:
protected_namespaces = ()
class ModelInfo(BaseModel):
id: Optional[
str
@@ -132,9 +133,11 @@ class Deployment(BaseModel):
litellm_params: LiteLLM_Params
model_info: ModelInfo
-def __init__(self, model_info: Optional[ModelInfo] = None, **params):
+def __init__(self, model_info: Optional[Union[ModelInfo, dict]] = None, **params):
if model_info is None:
model_info = ModelInfo()
elif isinstance(model_info, dict):
model_info = ModelInfo(**model_info)
super().__init__(model_info=model_info, **params)
def to_json(self, **kwargs):

View file

@@ -1990,9 +1990,6 @@ class Logging:
else:
litellm.cache.add_cache(result, **kwargs)
if isinstance(callback, CustomLogger): # custom logger class
-print_verbose(
-f"Running Async success callback: {callback}; self.stream: {self.stream}; async_complete_streaming_response: {self.model_call_details.get('async_complete_streaming_response', None)} result={result}"
-)
if self.stream == True:
if (
"async_complete_streaming_response"
@@ -2376,7 +2373,6 @@ def client(original_function):
if litellm.use_client or (
"use_client" in kwargs and kwargs["use_client"] == True
):
-print_verbose(f"litedebugger initialized")
if "lite_debugger" not in litellm.input_callback:
litellm.input_callback.append("lite_debugger")
if "lite_debugger" not in litellm.success_callback:
@@ -5912,6 +5908,16 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
return api_key
def get_utc_datetime():
import datetime as dt
from datetime import datetime
if hasattr(dt, "UTC"):
return datetime.now(dt.UTC) # type: ignore
else:
return datetime.utcnow() # type: ignore
def get_max_tokens(model: str):
"""
Get the maximum number of output tokens allowed for a given model.
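A brief note on the `get_utc_datetime` helper added above: `datetime.utcnow()` is deprecated from Python 3.12, and the timezone-aware replacement `datetime.now(datetime.UTC)` relies on `datetime.UTC`, which only exists on 3.11+; hence the `hasattr` branch. A standalone sketch of the same idea (not litellm code):

import datetime as dt
from datetime import datetime

def utc_now() -> datetime:
    """Return the current UTC time on both old and new Python versions."""
    if hasattr(dt, "UTC"):        # Python 3.11+: timezone-aware result
        return datetime.now(dt.UTC)
    return datetime.utcnow()      # older versions: naive UTC timestamp

# The router only needs the HH-MM bucket, which is stable across instances
# as long as they all derive it from UTC rather than the local timezone.
print(utc_now().strftime("%H-%M"))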