forked from phoenix/litellm-mirror
Merge pull request #2942 from BerriAI/litellm_fix_router_loading
Router Async Improvements
This commit is contained in:
commit
83e7ed94ce
10 changed files with 746 additions and 60 deletions
|
@ -81,9 +81,29 @@ class InMemoryCache(BaseCache):
|
||||||
return cached_response
|
return cached_response
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def batch_get_cache(self, keys: list, **kwargs):
|
||||||
|
return_val = []
|
||||||
|
for k in keys:
|
||||||
|
val = self.get_cache(key=k, **kwargs)
|
||||||
|
return_val.append(val)
|
||||||
|
return return_val
|
||||||
|
|
||||||
async def async_get_cache(self, key, **kwargs):
|
async def async_get_cache(self, key, **kwargs):
|
||||||
return self.get_cache(key=key, **kwargs)
|
return self.get_cache(key=key, **kwargs)
|
||||||
|
|
||||||
|
async def async_batch_get_cache(self, keys: list, **kwargs):
|
||||||
|
return_val = []
|
||||||
|
for k in keys:
|
||||||
|
val = self.get_cache(key=k, **kwargs)
|
||||||
|
return_val.append(val)
|
||||||
|
return return_val
|
||||||
|
|
||||||
|
async def async_increment(self, key, value: int, **kwargs):
|
||||||
|
# get the value
|
||||||
|
init_value = await self.async_get_cache(key=key) or 0
|
||||||
|
value = init_value + value
|
||||||
|
await self.async_set_cache(key, value, **kwargs)
|
||||||
|
|
||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
self.cache_dict.clear()
|
self.cache_dict.clear()
|
||||||
self.ttl_dict.clear()
|
self.ttl_dict.clear()
|
||||||
|
@ -246,6 +266,19 @@ class RedisCache(BaseCache):
|
||||||
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
|
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
|
||||||
await self.flush_cache_buffer()
|
await self.flush_cache_buffer()
|
||||||
|
|
||||||
|
async def async_increment(self, key, value: int, **kwargs):
|
||||||
|
_redis_client = self.init_async_client()
|
||||||
|
try:
|
||||||
|
async with _redis_client as redis_client:
|
||||||
|
await redis_client.incr(name=key, amount=value)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.error(
|
||||||
|
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
|
||||||
|
str(e),
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
async def flush_cache_buffer(self):
|
async def flush_cache_buffer(self):
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
|
f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
|
||||||
|
@ -283,6 +316,32 @@ class RedisCache(BaseCache):
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
|
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
|
||||||
|
|
||||||
|
def batch_get_cache(self, key_list) -> dict:
|
||||||
|
"""
|
||||||
|
Use Redis for bulk read operations
|
||||||
|
"""
|
||||||
|
key_value_dict = {}
|
||||||
|
try:
|
||||||
|
_keys = []
|
||||||
|
for cache_key in key_list:
|
||||||
|
cache_key = self.check_and_fix_namespace(key=cache_key)
|
||||||
|
_keys.append(cache_key)
|
||||||
|
results = self.redis_client.mget(keys=_keys)
|
||||||
|
|
||||||
|
# Associate the results back with their keys.
|
||||||
|
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
|
||||||
|
key_value_dict = dict(zip(key_list, results))
|
||||||
|
|
||||||
|
decoded_results = {
|
||||||
|
k.decode("utf-8"): self._get_cache_logic(v)
|
||||||
|
for k, v in key_value_dict.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return decoded_results
|
||||||
|
except Exception as e:
|
||||||
|
print_verbose(f"Error occurred in pipeline read - {str(e)}")
|
||||||
|
return key_value_dict
|
||||||
|
|
||||||
async def async_get_cache(self, key, **kwargs):
|
async def async_get_cache(self, key, **kwargs):
|
||||||
_redis_client = self.init_async_client()
|
_redis_client = self.init_async_client()
|
||||||
key = self.check_and_fix_namespace(key=key)
|
key = self.check_and_fix_namespace(key=key)
|
||||||
|
@ -301,7 +360,7 @@ class RedisCache(BaseCache):
|
||||||
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
|
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_get_cache_pipeline(self, key_list) -> dict:
|
async def async_batch_get_cache(self, key_list) -> dict:
|
||||||
"""
|
"""
|
||||||
Use Redis for bulk read operations
|
Use Redis for bulk read operations
|
||||||
"""
|
"""
|
||||||
|
@ -309,14 +368,11 @@ class RedisCache(BaseCache):
|
||||||
key_value_dict = {}
|
key_value_dict = {}
|
||||||
try:
|
try:
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
async with redis_client.pipeline(transaction=True) as pipe:
|
_keys = []
|
||||||
# Queue the get operations in the pipeline for all keys.
|
for cache_key in key_list:
|
||||||
for cache_key in key_list:
|
cache_key = self.check_and_fix_namespace(key=cache_key)
|
||||||
cache_key = self.check_and_fix_namespace(key=cache_key)
|
_keys.append(cache_key)
|
||||||
pipe.get(cache_key) # Queue GET command in pipeline
|
results = await redis_client.mget(keys=_keys)
|
||||||
|
|
||||||
# Execute the pipeline and await the results.
|
|
||||||
results = await pipe.execute()
|
|
||||||
|
|
||||||
# Associate the results back with their keys.
|
# Associate the results back with their keys.
|
||||||
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
|
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
|
||||||
|
@ -897,6 +953,39 @@ class DualCache(BaseCache):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
|
||||||
|
try:
|
||||||
|
result = [None for _ in range(len(keys))]
|
||||||
|
if self.in_memory_cache is not None:
|
||||||
|
in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
|
||||||
|
|
||||||
|
print_verbose(f"in_memory_result: {in_memory_result}")
|
||||||
|
if in_memory_result is not None:
|
||||||
|
result = in_memory_result
|
||||||
|
|
||||||
|
if None in result and self.redis_cache is not None and local_only == False:
|
||||||
|
"""
|
||||||
|
- for the none values in the result
|
||||||
|
- check the redis cache
|
||||||
|
"""
|
||||||
|
sublist_keys = [
|
||||||
|
key for key, value in zip(keys, result) if value is None
|
||||||
|
]
|
||||||
|
# If not found in in-memory cache, try fetching from Redis
|
||||||
|
redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
|
||||||
|
if redis_result is not None:
|
||||||
|
# Update in-memory cache with the value from Redis
|
||||||
|
for key in redis_result:
|
||||||
|
self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
|
||||||
|
|
||||||
|
for key, value in redis_result.items():
|
||||||
|
result[sublist_keys.index(key)] = value
|
||||||
|
|
||||||
|
print_verbose(f"async batch get cache: cache result: {result}")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
|
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
|
||||||
# Try to fetch from in-memory cache first
|
# Try to fetch from in-memory cache first
|
||||||
try:
|
try:
|
||||||
|
@ -930,6 +1019,50 @@ class DualCache(BaseCache):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def async_batch_get_cache(
|
||||||
|
self, keys: list, local_only: bool = False, **kwargs
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
result = [None for _ in range(len(keys))]
|
||||||
|
if self.in_memory_cache is not None:
|
||||||
|
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
|
||||||
|
keys, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
print_verbose(f"in_memory_result: {in_memory_result}")
|
||||||
|
if in_memory_result is not None:
|
||||||
|
result = in_memory_result
|
||||||
|
|
||||||
|
if None in result and self.redis_cache is not None and local_only == False:
|
||||||
|
"""
|
||||||
|
- for the none values in the result
|
||||||
|
- check the redis cache
|
||||||
|
"""
|
||||||
|
sublist_keys = [
|
||||||
|
key for key, value in zip(keys, result) if value is None
|
||||||
|
]
|
||||||
|
# If not found in in-memory cache, try fetching from Redis
|
||||||
|
redis_result = await self.redis_cache.async_batch_get_cache(
|
||||||
|
sublist_keys, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if redis_result is not None:
|
||||||
|
# Update in-memory cache with the value from Redis
|
||||||
|
for key in redis_result:
|
||||||
|
await self.in_memory_cache.async_set_cache(
|
||||||
|
key, redis_result[key], **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
sublist_dict = dict(zip(sublist_keys, redis_result))
|
||||||
|
|
||||||
|
for key, value in sublist_dict.items():
|
||||||
|
result[sublist_keys.index(key)] = value[key]
|
||||||
|
|
||||||
|
print_verbose(f"async batch get cache: cache result: {result}")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
|
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
|
||||||
try:
|
try:
|
||||||
if self.in_memory_cache is not None:
|
if self.in_memory_cache is not None:
|
||||||
|
@ -941,6 +1074,24 @@ class DualCache(BaseCache):
|
||||||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def async_increment_cache(
|
||||||
|
self, key, value: int, local_only: bool = False, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Key - the key in cache
|
||||||
|
|
||||||
|
Value - int - the value you want to increment by
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.in_memory_cache is not None:
|
||||||
|
await self.in_memory_cache.async_increment(key, value, **kwargs)
|
||||||
|
|
||||||
|
if self.redis_cache is not None and local_only == False:
|
||||||
|
await self.redis_cache.async_increment(key, value, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
if self.in_memory_cache is not None:
|
if self.in_memory_cache is not None:
|
||||||
self.in_memory_cache.flush_cache()
|
self.in_memory_cache.flush_cache()
|
||||||
|
|
|
@ -26,13 +26,12 @@ litellm_settings:
|
||||||
success_callback: ["prometheus"]
|
success_callback: ["prometheus"]
|
||||||
upperbound_key_generate_params:
|
upperbound_key_generate_params:
|
||||||
max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
|
max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
|
||||||
|
|
||||||
# litellm_settings:
|
router_settings:
|
||||||
# drop_params: True
|
routing_strategy: usage-based-routing-v2
|
||||||
# max_budget: 800021
|
redis_host: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
|
||||||
# budget_duration: 30d
|
redis_password: madeBerri@992
|
||||||
# # cache: true
|
redis_port: 16337
|
||||||
|
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
|
|
|
@ -79,7 +79,7 @@ class _PROXY_BatchRedisRequests(CustomLogger):
|
||||||
self.print_verbose(f"redis keys: {keys}")
|
self.print_verbose(f"redis keys: {keys}")
|
||||||
if len(keys) > 0:
|
if len(keys) > 0:
|
||||||
key_value_dict = (
|
key_value_dict = (
|
||||||
await litellm.cache.cache.async_get_cache_pipeline(
|
await litellm.cache.cache.async_batch_get_cache(
|
||||||
key_list=keys
|
key_list=keys
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -425,9 +425,10 @@ def run_server(
|
||||||
)
|
)
|
||||||
|
|
||||||
proxy_config = ProxyConfig()
|
proxy_config = ProxyConfig()
|
||||||
_, _, general_settings = asyncio.run(
|
_config = asyncio.run(proxy_config.get_config(config_file_path=config))
|
||||||
proxy_config.load_config(router=None, config_file_path=config)
|
general_settings = _config.get("general_settings", {})
|
||||||
)
|
if general_settings is None:
|
||||||
|
general_settings = {}
|
||||||
database_url = general_settings.get("database_url", None)
|
database_url = general_settings.get("database_url", None)
|
||||||
db_connection_pool_limit = general_settings.get(
|
db_connection_pool_limit = general_settings.get(
|
||||||
"database_connection_pool_limit", 100
|
"database_connection_pool_limit", 100
|
||||||
|
|
|
@ -2335,6 +2335,7 @@ class ProxyConfig:
|
||||||
"background_health_checks", False
|
"background_health_checks", False
|
||||||
)
|
)
|
||||||
health_check_interval = general_settings.get("health_check_interval", 300)
|
health_check_interval = general_settings.get("health_check_interval", 300)
|
||||||
|
|
||||||
router_params: dict = {
|
router_params: dict = {
|
||||||
"cache_responses": litellm.cache
|
"cache_responses": litellm.cache
|
||||||
!= None, # cache if user passed in cache values
|
!= None, # cache if user passed in cache values
|
||||||
|
|
|
@ -11,9 +11,9 @@ import copy, httpx
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
|
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
|
||||||
import random, threading, time, traceback, uuid
|
import random, threading, time, traceback, uuid
|
||||||
import litellm, openai
|
import litellm, openai, hashlib, json
|
||||||
from litellm.caching import RedisCache, InMemoryCache, DualCache
|
from litellm.caching import RedisCache, InMemoryCache, DualCache
|
||||||
|
import datetime as datetime_og
|
||||||
import logging, asyncio
|
import logging, asyncio
|
||||||
import inspect, concurrent
|
import inspect, concurrent
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
@ -21,11 +21,12 @@ from collections import defaultdict
|
||||||
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
||||||
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
|
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
|
||||||
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
|
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
|
||||||
|
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
|
||||||
from litellm.llms.custom_httpx.azure_dall_e_2 import (
|
from litellm.llms.custom_httpx.azure_dall_e_2 import (
|
||||||
CustomHTTPTransport,
|
CustomHTTPTransport,
|
||||||
AsyncCustomHTTPTransport,
|
AsyncCustomHTTPTransport,
|
||||||
)
|
)
|
||||||
from litellm.utils import ModelResponse, CustomStreamWrapper
|
from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
|
||||||
import copy
|
import copy
|
||||||
from litellm._logging import verbose_router_logger
|
from litellm._logging import verbose_router_logger
|
||||||
import logging
|
import logging
|
||||||
|
@ -273,6 +274,12 @@ class Router:
|
||||||
)
|
)
|
||||||
if isinstance(litellm.callbacks, list):
|
if isinstance(litellm.callbacks, list):
|
||||||
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
|
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
|
||||||
|
elif routing_strategy == "usage-based-routing-v2":
|
||||||
|
self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
|
||||||
|
router_cache=self.cache, model_list=self.model_list
|
||||||
|
)
|
||||||
|
if isinstance(litellm.callbacks, list):
|
||||||
|
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
|
||||||
elif routing_strategy == "latency-based-routing":
|
elif routing_strategy == "latency-based-routing":
|
||||||
self.lowestlatency_logger = LowestLatencyLoggingHandler(
|
self.lowestlatency_logger = LowestLatencyLoggingHandler(
|
||||||
router_cache=self.cache,
|
router_cache=self.cache,
|
||||||
|
@ -407,7 +414,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
|
f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -581,7 +588,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
|
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{"role": "user", "content": "prompt"}],
|
messages=[{"role": "user", "content": "prompt"}],
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -681,7 +688,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
|
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{"role": "user", "content": "prompt"}],
|
messages=[{"role": "user", "content": "prompt"}],
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -761,7 +768,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
|
f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
input=input,
|
input=input,
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -904,7 +911,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
|
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -1070,7 +1077,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
|
f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
input=input,
|
input=input,
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -1598,7 +1605,8 @@ class Router:
|
||||||
if deployment is None:
|
if deployment is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
current_minute = datetime.now().strftime("%H-%M")
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
# get current fails for deployment
|
# get current fails for deployment
|
||||||
# update the number of failed calls
|
# update the number of failed calls
|
||||||
# if it's > allowed fails
|
# if it's > allowed fails
|
||||||
|
@ -1636,11 +1644,29 @@ class Router:
|
||||||
key=deployment, value=updated_fails, ttl=cooldown_time
|
key=deployment, value=updated_fails, ttl=cooldown_time
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _async_get_cooldown_deployments(self):
|
||||||
|
"""
|
||||||
|
Async implementation of '_get_cooldown_deployments'
|
||||||
|
"""
|
||||||
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
|
# get the current cooldown list for that minute
|
||||||
|
cooldown_key = f"{current_minute}:cooldown_models"
|
||||||
|
|
||||||
|
# ----------------------
|
||||||
|
# Return cooldown models
|
||||||
|
# ----------------------
|
||||||
|
cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
|
||||||
|
|
||||||
|
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
|
||||||
|
return cooldown_models
|
||||||
|
|
||||||
def _get_cooldown_deployments(self):
|
def _get_cooldown_deployments(self):
|
||||||
"""
|
"""
|
||||||
Get the list of models being cooled down for this minute
|
Get the list of models being cooled down for this minute
|
||||||
"""
|
"""
|
||||||
current_minute = datetime.now().strftime("%H-%M")
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
# get the current cooldown list for that minute
|
# get the current cooldown list for that minute
|
||||||
cooldown_key = f"{current_minute}:cooldown_models"
|
cooldown_key = f"{current_minute}:cooldown_models"
|
||||||
|
|
||||||
|
@ -2065,6 +2091,34 @@ class Router:
|
||||||
local_only=True,
|
local_only=True,
|
||||||
) # cache for 1 hr
|
) # cache for 1 hr
|
||||||
|
|
||||||
|
def _generate_model_id(self, model_group: str, litellm_params: dict):
|
||||||
|
"""
|
||||||
|
Helper function to consistently generate the same id for a deployment
|
||||||
|
|
||||||
|
- create a string from all the litellm params
|
||||||
|
- hash
|
||||||
|
- use hash as id
|
||||||
|
"""
|
||||||
|
concat_str = model_group
|
||||||
|
for k, v in litellm_params.items():
|
||||||
|
if isinstance(k, str):
|
||||||
|
concat_str += k
|
||||||
|
elif isinstance(k, dict):
|
||||||
|
concat_str += json.dumps(k)
|
||||||
|
else:
|
||||||
|
concat_str += str(k)
|
||||||
|
|
||||||
|
if isinstance(v, str):
|
||||||
|
concat_str += v
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
concat_str += json.dumps(v)
|
||||||
|
else:
|
||||||
|
concat_str += str(v)
|
||||||
|
|
||||||
|
hash_object = hashlib.sha256(concat_str.encode())
|
||||||
|
|
||||||
|
return hash_object.hexdigest()
|
||||||
|
|
||||||
def set_model_list(self, model_list: list):
|
def set_model_list(self, model_list: list):
|
||||||
original_model_list = copy.deepcopy(model_list)
|
original_model_list = copy.deepcopy(model_list)
|
||||||
self.model_list = []
|
self.model_list = []
|
||||||
|
@ -2080,7 +2134,13 @@ class Router:
|
||||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||||
_litellm_params[k] = litellm.get_secret(v)
|
_litellm_params[k] = litellm.get_secret(v)
|
||||||
|
|
||||||
_model_info = model.pop("model_info", {})
|
_model_info: dict = model.pop("model_info", {})
|
||||||
|
|
||||||
|
# check if model info has id
|
||||||
|
if "id" not in _model_info:
|
||||||
|
_id = self._generate_model_id(_model_name, _litellm_params)
|
||||||
|
_model_info["id"] = _id
|
||||||
|
|
||||||
deployment = Deployment(
|
deployment = Deployment(
|
||||||
**model,
|
**model,
|
||||||
model_name=_model_name,
|
model_name=_model_name,
|
||||||
|
@ -2279,7 +2339,8 @@ class Router:
|
||||||
_rate_limit_error = False
|
_rate_limit_error = False
|
||||||
|
|
||||||
## get model group RPM ##
|
## get model group RPM ##
|
||||||
current_minute = datetime.now().strftime("%H-%M")
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
rpm_key = f"{model}:rpm:{current_minute}"
|
rpm_key = f"{model}:rpm:{current_minute}"
|
||||||
model_group_cache = (
|
model_group_cache = (
|
||||||
self.cache.get_cache(key=rpm_key, local_only=True) or {}
|
self.cache.get_cache(key=rpm_key, local_only=True) or {}
|
||||||
|
@ -2364,7 +2425,7 @@ class Router:
|
||||||
|
|
||||||
return _returned_deployments
|
return _returned_deployments
|
||||||
|
|
||||||
def get_available_deployment(
|
def _common_checks_available_deployment(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: Optional[List[Dict[str, str]]] = None,
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
@ -2372,11 +2433,11 @@ class Router:
|
||||||
specific_deployment: Optional[bool] = False,
|
specific_deployment: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Returns the deployment based on routing strategy
|
Common checks for 'get_available_deployment' across sync + async call.
|
||||||
"""
|
|
||||||
|
|
||||||
# users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
|
If 'healthy_deployments' returned is None, this means the user chose a specific deployment
|
||||||
# When this was no explicit we had several issues with fallbacks timing out
|
"""
|
||||||
|
# check if aliases set on litellm model alias map
|
||||||
if specific_deployment == True:
|
if specific_deployment == True:
|
||||||
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
|
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
|
||||||
for deployment in self.model_list:
|
for deployment in self.model_list:
|
||||||
|
@ -2384,12 +2445,11 @@ class Router:
|
||||||
if deployment_model == model:
|
if deployment_model == model:
|
||||||
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
|
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
|
||||||
# return the first deployment where the `model` matches the specificed deployment name
|
# return the first deployment where the `model` matches the specificed deployment name
|
||||||
return deployment
|
return deployment, None
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
|
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if aliases set on litellm model alias map
|
|
||||||
if model in self.model_group_alias:
|
if model in self.model_group_alias:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
|
f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
|
||||||
|
@ -2401,7 +2461,7 @@ class Router:
|
||||||
self.default_deployment
|
self.default_deployment
|
||||||
) # self.default_deployment
|
) # self.default_deployment
|
||||||
updated_deployment["litellm_params"]["model"] = model
|
updated_deployment["litellm_params"]["model"] = model
|
||||||
return updated_deployment
|
return updated_deployment, None
|
||||||
|
|
||||||
## get healthy deployments
|
## get healthy deployments
|
||||||
### get all deployments
|
### get all deployments
|
||||||
|
@ -2416,6 +2476,118 @@ class Router:
|
||||||
f"initial list of deployments: {healthy_deployments}"
|
f"initial list of deployments: {healthy_deployments}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
verbose_router_logger.debug(
|
||||||
|
f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
|
||||||
|
)
|
||||||
|
if len(healthy_deployments) == 0:
|
||||||
|
raise ValueError(f"No healthy deployment available, passed model={model}")
|
||||||
|
if litellm.model_alias_map and model in litellm.model_alias_map:
|
||||||
|
model = litellm.model_alias_map[
|
||||||
|
model
|
||||||
|
] # update the model to the actual value if an alias has been passed in
|
||||||
|
|
||||||
|
return model, healthy_deployments
|
||||||
|
|
||||||
|
async def async_get_available_deployment(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
input: Optional[Union[str, List]] = None,
|
||||||
|
specific_deployment: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Async implementation of 'get_available_deployments'.
|
||||||
|
|
||||||
|
Allows all cache calls to be made async => 10x perf impact (8rps -> 100 rps).
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
self.routing_strategy != "usage-based-routing-v2"
|
||||||
|
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
|
||||||
|
return self.get_available_deployment(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
specific_deployment=specific_deployment,
|
||||||
|
)
|
||||||
|
|
||||||
|
model, healthy_deployments = self._common_checks_available_deployment(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
specific_deployment=specific_deployment,
|
||||||
|
)
|
||||||
|
|
||||||
|
if healthy_deployments is None:
|
||||||
|
return model
|
||||||
|
|
||||||
|
# filter out the deployments currently cooling down
|
||||||
|
deployments_to_remove = []
|
||||||
|
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
|
||||||
|
cooldown_deployments = await self._async_get_cooldown_deployments()
|
||||||
|
verbose_router_logger.debug(
|
||||||
|
f"async cooldown deployments: {cooldown_deployments}"
|
||||||
|
)
|
||||||
|
# Find deployments in model_list whose model_id is cooling down
|
||||||
|
for deployment in healthy_deployments:
|
||||||
|
deployment_id = deployment["model_info"]["id"]
|
||||||
|
if deployment_id in cooldown_deployments:
|
||||||
|
deployments_to_remove.append(deployment)
|
||||||
|
# remove unhealthy deployments from healthy deployments
|
||||||
|
for deployment in deployments_to_remove:
|
||||||
|
healthy_deployments.remove(deployment)
|
||||||
|
|
||||||
|
# filter pre-call checks
|
||||||
|
if self.enable_pre_call_checks and messages is not None:
|
||||||
|
healthy_deployments = self._pre_call_checks(
|
||||||
|
model=model, healthy_deployments=healthy_deployments, messages=messages
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.routing_strategy == "usage-based-routing-v2"
|
||||||
|
and self.lowesttpm_logger_v2 is not None
|
||||||
|
):
|
||||||
|
deployment = await self.lowesttpm_logger_v2.async_get_available_deployments(
|
||||||
|
model_group=model,
|
||||||
|
healthy_deployments=healthy_deployments,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
|
|
||||||
|
if deployment is None:
|
||||||
|
verbose_router_logger.info(
|
||||||
|
f"get_available_deployment for model: {model}, No deployment available"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"No deployments available for selected model, passed model={model}"
|
||||||
|
)
|
||||||
|
verbose_router_logger.info(
|
||||||
|
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
|
||||||
|
)
|
||||||
|
return deployment
|
||||||
|
|
||||||
|
def get_available_deployment(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
input: Optional[Union[str, List]] = None,
|
||||||
|
specific_deployment: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Returns the deployment based on routing strategy
|
||||||
|
"""
|
||||||
|
# users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
|
||||||
|
# When this was no explicit we had several issues with fallbacks timing out
|
||||||
|
|
||||||
|
model, healthy_deployments = self._common_checks_available_deployment(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
specific_deployment=specific_deployment,
|
||||||
|
)
|
||||||
|
|
||||||
|
if healthy_deployments is None:
|
||||||
|
return model
|
||||||
|
|
||||||
# filter out the deployments currently cooling down
|
# filter out the deployments currently cooling down
|
||||||
deployments_to_remove = []
|
deployments_to_remove = []
|
||||||
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
|
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
|
||||||
|
@ -2436,16 +2608,6 @@ class Router:
|
||||||
model=model, healthy_deployments=healthy_deployments, messages=messages
|
model=model, healthy_deployments=healthy_deployments, messages=messages
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_router_logger.debug(
|
|
||||||
f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
|
|
||||||
)
|
|
||||||
if len(healthy_deployments) == 0:
|
|
||||||
raise ValueError(f"No healthy deployment available, passed model={model}")
|
|
||||||
if litellm.model_alias_map and model in litellm.model_alias_map:
|
|
||||||
model = litellm.model_alias_map[
|
|
||||||
model
|
|
||||||
] # update the model to the actual value if an alias has been passed in
|
|
||||||
|
|
||||||
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
|
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
|
||||||
deployment = self.leastbusy_logger.get_available_deployments(
|
deployment = self.leastbusy_logger.get_available_deployments(
|
||||||
model_group=model, healthy_deployments=healthy_deployments
|
model_group=model, healthy_deployments=healthy_deployments
|
||||||
|
@ -2507,7 +2669,16 @@ class Router:
|
||||||
messages=messages,
|
messages=messages,
|
||||||
input=input,
|
input=input,
|
||||||
)
|
)
|
||||||
|
elif (
|
||||||
|
self.routing_strategy == "usage-based-routing-v2"
|
||||||
|
and self.lowesttpm_logger_v2 is not None
|
||||||
|
):
|
||||||
|
deployment = self.lowesttpm_logger_v2.get_available_deployments(
|
||||||
|
model_group=model,
|
||||||
|
healthy_deployments=healthy_deployments,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
if deployment is None:
|
if deployment is None:
|
||||||
verbose_router_logger.info(
|
verbose_router_logger.info(
|
||||||
f"get_available_deployment for model: {model}, No deployment available"
|
f"get_available_deployment for model: {model}, No deployment available"
|
||||||
|
|
325
litellm/router_strategy/lowest_tpm_rpm_v2.py
Normal file
325
litellm/router_strategy/lowest_tpm_rpm_v2.py
Normal file
|
@ -0,0 +1,325 @@
|
||||||
|
#### What this does ####
|
||||||
|
# identifies lowest tpm deployment
|
||||||
|
|
||||||
|
import dotenv, os, requests, random
|
||||||
|
from typing import Optional, Union, List, Dict
|
||||||
|
import datetime as datetime_og
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
|
import traceback, asyncio
|
||||||
|
from litellm import token_counter
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm._logging import verbose_router_logger
|
||||||
|
from litellm.utils import print_verbose, get_utc_datetime
|
||||||
|
|
||||||
|
|
||||||
|
class LowestTPMLoggingHandler_v2(CustomLogger):
|
||||||
|
"""
|
||||||
|
Updated version of TPM/RPM Logging.
|
||||||
|
|
||||||
|
Meant to work across instances.
|
||||||
|
|
||||||
|
Caches individual models, not model_groups
|
||||||
|
|
||||||
|
Uses batch get (redis.mget)
|
||||||
|
|
||||||
|
Increments tpm/rpm limit using redis.incr
|
||||||
|
"""
|
||||||
|
|
||||||
|
test_flag: bool = False
|
||||||
|
logged_success: int = 0
|
||||||
|
logged_failure: int = 0
|
||||||
|
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
|
||||||
|
|
||||||
|
def __init__(self, router_cache: DualCache, model_list: list):
|
||||||
|
self.router_cache = router_cache
|
||||||
|
self.model_list = model_list
|
||||||
|
|
||||||
|
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
try:
|
||||||
|
"""
|
||||||
|
Update TPM/RPM usage on success
|
||||||
|
"""
|
||||||
|
if kwargs["litellm_params"].get("metadata") is None:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
"model_group", None
|
||||||
|
)
|
||||||
|
|
||||||
|
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||||
|
if model_group is None or id is None:
|
||||||
|
return
|
||||||
|
elif isinstance(id, int):
|
||||||
|
id = str(id)
|
||||||
|
|
||||||
|
total_tokens = response_obj["usage"]["total_tokens"]
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Setup values
|
||||||
|
# ------------
|
||||||
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
|
tpm_key = f"{model_group}:tpm:{current_minute}"
|
||||||
|
rpm_key = f"{model_group}:rpm:{current_minute}"
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Update usage
|
||||||
|
# ------------
|
||||||
|
|
||||||
|
## TPM
|
||||||
|
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
|
||||||
|
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
|
||||||
|
|
||||||
|
self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
|
||||||
|
|
||||||
|
## RPM
|
||||||
|
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
|
||||||
|
request_count_dict[id] = request_count_dict.get(id, 0) + 1
|
||||||
|
|
||||||
|
self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
|
||||||
|
|
||||||
|
### TESTING ###
|
||||||
|
if self.test_flag:
|
||||||
|
self.logged_success += 1
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
try:
|
||||||
|
"""
|
||||||
|
Update TPM/RPM usage on success
|
||||||
|
"""
|
||||||
|
if kwargs["litellm_params"].get("metadata") is None:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
"model_group", None
|
||||||
|
)
|
||||||
|
|
||||||
|
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||||
|
if model_group is None or id is None:
|
||||||
|
return
|
||||||
|
elif isinstance(id, int):
|
||||||
|
id = str(id)
|
||||||
|
|
||||||
|
total_tokens = response_obj["usage"]["total_tokens"]
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Setup values
|
||||||
|
# ------------
|
||||||
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime(
|
||||||
|
"%H-%M"
|
||||||
|
) # use the same timezone regardless of system clock
|
||||||
|
|
||||||
|
tpm_key = f"{id}:tpm:{current_minute}"
|
||||||
|
rpm_key = f"{id}:rpm:{current_minute}"
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Update usage
|
||||||
|
# ------------
|
||||||
|
# update cache
|
||||||
|
|
||||||
|
## TPM
|
||||||
|
await self.router_cache.async_increment_cache(
|
||||||
|
key=tpm_key, value=total_tokens
|
||||||
|
)
|
||||||
|
## RPM
|
||||||
|
await self.router_cache.async_increment_cache(key=rpm_key, value=1)
|
||||||
|
|
||||||
|
### TESTING ###
|
||||||
|
if self.test_flag:
|
||||||
|
self.logged_success += 1
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _common_checks_available_deployment(
|
||||||
|
self,
|
||||||
|
model_group: str,
|
||||||
|
healthy_deployments: list,
|
||||||
|
tpm_keys: list,
|
||||||
|
tpm_values: list,
|
||||||
|
rpm_keys: list,
|
||||||
|
rpm_values: list,
|
||||||
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
input: Optional[Union[str, List]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Common checks for get available deployment, across sync + async implementations
|
||||||
|
"""
|
||||||
|
tpm_dict = {} # {model_id: 1, ..}
|
||||||
|
for idx, key in enumerate(tpm_keys):
|
||||||
|
tpm_dict[tpm_keys[idx]] = tpm_values[idx]
|
||||||
|
|
||||||
|
rpm_dict = {} # {model_id: 1, ..}
|
||||||
|
for idx, key in enumerate(rpm_keys):
|
||||||
|
rpm_dict[rpm_keys[idx]] = rpm_values[idx]
|
||||||
|
|
||||||
|
try:
|
||||||
|
input_tokens = token_counter(messages=messages, text=input)
|
||||||
|
except:
|
||||||
|
input_tokens = 0
|
||||||
|
verbose_router_logger.debug(f"input_tokens={input_tokens}")
|
||||||
|
# -----------------------
|
||||||
|
# Find lowest used model
|
||||||
|
# ----------------------
|
||||||
|
lowest_tpm = float("inf")
|
||||||
|
|
||||||
|
if tpm_dict is None: # base case - none of the deployments have been used
|
||||||
|
# initialize a tpm dict with {model_id: 0}
|
||||||
|
tpm_dict = {}
|
||||||
|
for deployment in healthy_deployments:
|
||||||
|
tpm_dict[deployment["model_info"]["id"]] = 0
|
||||||
|
else:
|
||||||
|
for d in healthy_deployments:
|
||||||
|
## if healthy deployment not yet used
|
||||||
|
if d["model_info"]["id"] not in tpm_dict:
|
||||||
|
tpm_dict[d["model_info"]["id"]] = 0
|
||||||
|
|
||||||
|
all_deployments = tpm_dict
|
||||||
|
|
||||||
|
deployment = None
|
||||||
|
for item, item_tpm in all_deployments.items():
|
||||||
|
## get the item from model list
|
||||||
|
_deployment = None
|
||||||
|
for m in healthy_deployments:
|
||||||
|
if item == m["model_info"]["id"]:
|
||||||
|
_deployment = m
|
||||||
|
|
||||||
|
if _deployment is None:
|
||||||
|
continue # skip to next one
|
||||||
|
|
||||||
|
_deployment_tpm = None
|
||||||
|
if _deployment_tpm is None:
|
||||||
|
_deployment_tpm = _deployment.get("tpm")
|
||||||
|
if _deployment_tpm is None:
|
||||||
|
_deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
|
||||||
|
if _deployment_tpm is None:
|
||||||
|
_deployment_tpm = _deployment.get("model_info", {}).get("tpm")
|
||||||
|
if _deployment_tpm is None:
|
||||||
|
_deployment_tpm = float("inf")
|
||||||
|
|
||||||
|
_deployment_rpm = None
|
||||||
|
if _deployment_rpm is None:
|
||||||
|
_deployment_rpm = _deployment.get("rpm")
|
||||||
|
if _deployment_rpm is None:
|
||||||
|
_deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
|
||||||
|
if _deployment_rpm is None:
|
||||||
|
_deployment_rpm = _deployment.get("model_info", {}).get("rpm")
|
||||||
|
if _deployment_rpm is None:
|
||||||
|
_deployment_rpm = float("inf")
|
||||||
|
|
||||||
|
if item_tpm + input_tokens > _deployment_tpm:
|
||||||
|
continue
|
||||||
|
elif (rpm_dict is not None and item in rpm_dict) and (
|
||||||
|
rpm_dict[item] + 1 > _deployment_rpm
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
elif item_tpm < lowest_tpm:
|
||||||
|
lowest_tpm = item_tpm
|
||||||
|
deployment = _deployment
|
||||||
|
print_verbose("returning picked lowest tpm/rpm deployment.")
|
||||||
|
return deployment
|
||||||
|
|
||||||
|
async def async_get_available_deployments(
|
||||||
|
self,
|
||||||
|
model_group: str,
|
||||||
|
healthy_deployments: list,
|
||||||
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
input: Optional[Union[str, List]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Async implementation of get deployments.
|
||||||
|
|
||||||
|
Reduces time to retrieve the tpm/rpm values from cache
|
||||||
|
"""
|
||||||
|
# get list of potential deployments
|
||||||
|
verbose_router_logger.debug(
|
||||||
|
f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
|
||||||
|
)
|
||||||
|
|
||||||
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
|
tpm_keys = []
|
||||||
|
rpm_keys = []
|
||||||
|
for m in healthy_deployments:
|
||||||
|
if isinstance(m, dict):
|
||||||
|
id = m.get("model_info", {}).get(
|
||||||
|
"id"
|
||||||
|
) # a deployment should always have an 'id'. this is set in router.py
|
||||||
|
tpm_key = "{}:tpm:{}".format(id, current_minute)
|
||||||
|
rpm_key = "{}:rpm:{}".format(id, current_minute)
|
||||||
|
|
||||||
|
tpm_keys.append(tpm_key)
|
||||||
|
rpm_keys.append(rpm_key)
|
||||||
|
|
||||||
|
tpm_values = await self.router_cache.async_batch_get_cache(
|
||||||
|
keys=tpm_keys
|
||||||
|
) # [1, 2, None, ..]
|
||||||
|
rpm_values = await self.router_cache.async_batch_get_cache(
|
||||||
|
keys=rpm_keys
|
||||||
|
) # [1, 2, None, ..]
|
||||||
|
|
||||||
|
return self._common_checks_available_deployment(
|
||||||
|
model_group=model_group,
|
||||||
|
healthy_deployments=healthy_deployments,
|
||||||
|
tpm_keys=tpm_keys,
|
||||||
|
tpm_values=tpm_values,
|
||||||
|
rpm_keys=rpm_keys,
|
||||||
|
rpm_values=rpm_values,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_available_deployments(
|
||||||
|
self,
|
||||||
|
model_group: str,
|
||||||
|
healthy_deployments: list,
|
||||||
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
input: Optional[Union[str, List]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Returns a deployment with the lowest TPM/RPM usage.
|
||||||
|
"""
|
||||||
|
# get list of potential deployments
|
||||||
|
verbose_router_logger.debug(
|
||||||
|
f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
|
||||||
|
)
|
||||||
|
|
||||||
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
|
tpm_keys = []
|
||||||
|
rpm_keys = []
|
||||||
|
for m in healthy_deployments:
|
||||||
|
if isinstance(m, dict):
|
||||||
|
id = m.get("model_info", {}).get(
|
||||||
|
"id"
|
||||||
|
) # a deployment should always have an 'id'. this is set in router.py
|
||||||
|
tpm_key = "{}:tpm:{}".format(id, current_minute)
|
||||||
|
rpm_key = "{}:rpm:{}".format(id, current_minute)
|
||||||
|
|
||||||
|
tpm_keys.append(tpm_key)
|
||||||
|
rpm_keys.append(rpm_key)
|
||||||
|
|
||||||
|
tpm_values = self.router_cache.batch_get_cache(
|
||||||
|
keys=tpm_keys
|
||||||
|
) # [1, 2, None, ..]
|
||||||
|
rpm_values = self.router_cache.batch_get_cache(
|
||||||
|
keys=rpm_keys
|
||||||
|
) # [1, 2, None, ..]
|
||||||
|
|
||||||
|
return self._common_checks_available_deployment(
|
||||||
|
model_group=model_group,
|
||||||
|
healthy_deployments=healthy_deployments,
|
||||||
|
tpm_keys=tpm_keys,
|
||||||
|
tpm_values=tpm_values,
|
||||||
|
rpm_keys=rpm_keys,
|
||||||
|
rpm_values=rpm_values,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
)
|
|
@ -932,6 +932,35 @@ def test_openai_completion_on_router():
|
||||||
# test_openai_completion_on_router()
|
# test_openai_completion_on_router()
|
||||||
|
|
||||||
|
|
||||||
|
def test_consistent_model_id():
|
||||||
|
"""
|
||||||
|
- For a given model group + litellm params, assert the model id is always the same
|
||||||
|
|
||||||
|
Test on `_generate_model_id`
|
||||||
|
|
||||||
|
Test on `set_model_list`
|
||||||
|
|
||||||
|
Test on `_add_deployment`
|
||||||
|
"""
|
||||||
|
model_group = "gpt-3.5-turbo"
|
||||||
|
litellm_params = {
|
||||||
|
"model": "openai/my-fake-model",
|
||||||
|
"api_key": "my-fake-key",
|
||||||
|
"api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
|
||||||
|
"stream_timeout": 0.001,
|
||||||
|
}
|
||||||
|
|
||||||
|
id1 = Router()._generate_model_id(
|
||||||
|
model_group=model_group, litellm_params=litellm_params
|
||||||
|
)
|
||||||
|
|
||||||
|
id2 = Router()._generate_model_id(
|
||||||
|
model_group=model_group, litellm_params=litellm_params
|
||||||
|
)
|
||||||
|
|
||||||
|
assert id1 == id2
|
||||||
|
|
||||||
|
|
||||||
def test_reading_keys_os_environ():
|
def test_reading_keys_os_environ():
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ class ModelConfig(BaseModel):
|
||||||
rpm: int
|
rpm: int
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
class RouterConfig(BaseModel):
|
class RouterConfig(BaseModel):
|
||||||
|
@ -45,7 +45,8 @@ class RouterConfig(BaseModel):
|
||||||
] = "simple-shuffle"
|
] = "simple-shuffle"
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
id: Optional[
|
id: Optional[
|
||||||
|
@ -132,9 +133,11 @@ class Deployment(BaseModel):
|
||||||
litellm_params: LiteLLM_Params
|
litellm_params: LiteLLM_Params
|
||||||
model_info: ModelInfo
|
model_info: ModelInfo
|
||||||
|
|
||||||
def __init__(self, model_info: Optional[ModelInfo] = None, **params):
|
def __init__(self, model_info: Optional[Union[ModelInfo, dict]] = None, **params):
|
||||||
if model_info is None:
|
if model_info is None:
|
||||||
model_info = ModelInfo()
|
model_info = ModelInfo()
|
||||||
|
elif isinstance(model_info, dict):
|
||||||
|
model_info = ModelInfo(**model_info)
|
||||||
super().__init__(model_info=model_info, **params)
|
super().__init__(model_info=model_info, **params)
|
||||||
|
|
||||||
def to_json(self, **kwargs):
|
def to_json(self, **kwargs):
|
||||||
|
@ -146,7 +149,7 @@ class Deployment(BaseModel):
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = "allow"
|
extra = "allow"
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
# Define custom behavior for the 'in' operator
|
# Define custom behavior for the 'in' operator
|
||||||
|
|
|
@ -1990,9 +1990,6 @@ class Logging:
|
||||||
else:
|
else:
|
||||||
litellm.cache.add_cache(result, **kwargs)
|
litellm.cache.add_cache(result, **kwargs)
|
||||||
if isinstance(callback, CustomLogger): # custom logger class
|
if isinstance(callback, CustomLogger): # custom logger class
|
||||||
print_verbose(
|
|
||||||
f"Running Async success callback: {callback}; self.stream: {self.stream}; async_complete_streaming_response: {self.model_call_details.get('async_complete_streaming_response', None)} result={result}"
|
|
||||||
)
|
|
||||||
if self.stream == True:
|
if self.stream == True:
|
||||||
if (
|
if (
|
||||||
"async_complete_streaming_response"
|
"async_complete_streaming_response"
|
||||||
|
@ -2376,7 +2373,6 @@ def client(original_function):
|
||||||
if litellm.use_client or (
|
if litellm.use_client or (
|
||||||
"use_client" in kwargs and kwargs["use_client"] == True
|
"use_client" in kwargs and kwargs["use_client"] == True
|
||||||
):
|
):
|
||||||
print_verbose(f"litedebugger initialized")
|
|
||||||
if "lite_debugger" not in litellm.input_callback:
|
if "lite_debugger" not in litellm.input_callback:
|
||||||
litellm.input_callback.append("lite_debugger")
|
litellm.input_callback.append("lite_debugger")
|
||||||
if "lite_debugger" not in litellm.success_callback:
|
if "lite_debugger" not in litellm.success_callback:
|
||||||
|
@ -5912,6 +5908,16 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
def get_utc_datetime():
|
||||||
|
import datetime as dt
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
if hasattr(dt, "UTC"):
|
||||||
|
return datetime.now(dt.UTC) # type: ignore
|
||||||
|
else:
|
||||||
|
return datetime.utcnow() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def get_max_tokens(model: str):
|
def get_max_tokens(model: str):
|
||||||
"""
|
"""
|
||||||
Get the maximum number of output tokens allowed for a given model.
|
Get the maximum number of output tokens allowed for a given model.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue