fix(router.py): fix caching for tracking cooldowns + usage

Krrish Dholakia 2023-11-23 11:13:24 -08:00
parent 94c1d71b2c
commit 61fc76a8c4
5 changed files with 148 additions and 75 deletions


@@ -89,6 +89,7 @@ const sidebars = {
     "routing",
     "rules",
     "set_keys",
+    "budget_manager",
     "completion/token_usage",
     {
       type: 'category',
@@ -157,7 +158,6 @@ const sidebars = {
       label: 'Extras',
       items: [
         'extras/contributing',
-        "budget_manager",
         "proxy_server",
         {
           type: "category",


@@ -23,6 +23,9 @@ def get_prompt(*args, **kwargs):
         return prompt
     return None

+def print_verbose(print_statement):
+    if litellm.set_verbose:
+        print(print_statement) # noqa

 class BaseCache:
     def set_cache(self, key, value, **kwargs):
@@ -32,6 +35,34 @@ class BaseCache:
         raise NotImplementedError

+class InMemoryCache(BaseCache):
+    def __init__(self):
+        # if users don't provider one, use the default litellm cache
+        self.cache_dict = {}
+        self.ttl_dict = {}
+
+    def set_cache(self, key, value, **kwargs):
+        self.cache_dict[key] = value
+        if "ttl" in kwargs:
+            self.ttl_dict[key] = time.time() + kwargs["ttl"]
+
+    def get_cache(self, key, **kwargs):
+        if key in self.cache_dict:
+            if key in self.ttl_dict:
+                if time.time() > self.ttl_dict[key]:
+                    self.cache_dict.pop(key, None)
+                    return None
+            original_cached_response = self.cache_dict[key]
+            try:
+                cached_response = json.loads(original_cached_response)
+            except:
+                cached_response = original_cached_response
+            if isinstance(cached_response, dict):
+                cached_response['cache'] = True # set cache-hit flag to True
+            return cached_response
+        return None
+
 class RedisCache(BaseCache):
     def __init__(self, host, port, password):
         import redis
@@ -65,7 +96,58 @@ class RedisCache(BaseCache):
             traceback.print_exc()
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

+class DualCache(BaseCache):
+    """
+    This updates both Redis and an in-memory cache simultaneously.
+    When data is updated or inserted, it is written to both the in-memory cache + Redis.
+    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
+    """
+    def __init__(self, in_memory_cache: InMemoryCache = None, redis_cache: RedisCache = None) -> None:
+        super().__init__()
+        # If in_memory_cache is not provided, use the default InMemoryCache
+        self.in_memory_cache = in_memory_cache or InMemoryCache()
+        # If redis_cache is not provided, use the default RedisCache
+        self.redis_cache = redis_cache
+
+    def set_cache(self, key, value, **kwargs):
+        # Update both Redis and in-memory cache
+        try:
+            print_verbose(f"set cache: key: {key}; value: {value}")
+            if self.in_memory_cache is not None:
+                self.in_memory_cache.set_cache(key, value, **kwargs)
+            if self.redis_cache is not None:
+                self.redis_cache.set_cache(key, value, **kwargs)
+        except Exception as e:
+            print_verbose(e)
+
+    def get_cache(self, key, **kwargs):
+        # Try to fetch from in-memory cache first
+        try:
+            print_verbose(f"get cache: cache key: {key}")
+            result = None
+            if self.in_memory_cache is not None:
+                in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
+                if in_memory_result is not None:
+                    result = in_memory_result
+            if self.redis_cache is not None:
+                # If not found in in-memory cache, try fetching from Redis
+                redis_result = self.redis_cache.get_cache(key, **kwargs)
+                if redis_result is not None:
+                    # Update in-memory cache with the value from Redis
+                    self.in_memory_cache.set_cache(key, redis_result, **kwargs)
+                    result = redis_result
+            print_verbose(f"get cache: cache result: {result}")
+            return result
+        except Exception as e:
+            traceback.print_exc()
+
+#### DEPRECATED ####
 class HostedCache(BaseCache):
     def set_cache(self, key, value, **kwargs):
         if "ttl" in kwargs:
@@ -91,33 +173,7 @@ class HostedCache(BaseCache):
             return cached_response

-class InMemoryCache(BaseCache):
-    def __init__(self):
-        # if users don't provider one, use the default litellm cache
-        self.cache_dict = {}
-        self.ttl_dict = {}
-
-    def set_cache(self, key, value, **kwargs):
-        self.cache_dict[key] = value
-        if "ttl" in kwargs:
-            self.ttl_dict[key] = time.time() + kwargs["ttl"]
-
-    def get_cache(self, key, **kwargs):
-        if key in self.cache_dict:
-            if key in self.ttl_dict:
-                if time.time() > self.ttl_dict[key]:
-                    self.cache_dict.pop(key, None)
-                    return None
-            original_cached_response = self.cache_dict[key]
-            try:
-                cached_response = json.loads(original_cached_response)
-            except:
-                cached_response = original_cached_response
-            cached_response['cache'] = True # set cache-hit flag to True
-            return cached_response
-        return None
-
+#### LiteLLM.Completion Cache ####
 class Cache:
     def __init__(
         self,
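For orientation, a minimal usage sketch of the new DualCache (this snippet is not part of the commit; the key names, TTL, and Redis credentials are made up): writes go to both layers, reads check the in-memory layer first, and a Redis hit is copied back into memory before being returned.

    from litellm.caching import DualCache, RedisCache

    # In-memory only: redis_cache defaults to None, so the Redis branch is skipped
    local_cache = DualCache()
    local_cache.set_cache("greeting", "hello", ttl=60)   # stored in the InMemoryCache with a 60s TTL
    print(local_cache.get_cache("greeting"))             # -> "hello"

    # With Redis attached (host/port/password are placeholders), every write hits both layers
    dual_cache = DualCache(redis_cache=RedisCache(host="localhost", port=6379, password="hunter2"))
    dual_cache.set_cache("greeting", "hello", ttl=60)
    print(dual_cache.get_cache("greeting"))              # memory is checked first; a Redis value back-fills memory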


@@ -194,7 +194,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
                 "role": "user",
                 "content": "this is a test request, write a short poem"
             }
-        ])
+        ], max_tokens=256)
         click.echo(f'\nLiteLLM: response from proxy {response}')

         print("\n Making streaming request to proxy")


@@ -11,6 +11,7 @@ from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
 import litellm, openai
+from litellm.caching import RedisCache, InMemoryCache, DualCache
 import logging, asyncio
 import inspect
 from openai import AsyncOpenAI
@@ -46,6 +47,7 @@ class Router:
                  num_retries: int = 0,
                  timeout: float = 600,
                  default_litellm_params = {}, # default params for Router.chat.completion.create
+                 set_verbose: bool = False,
                  routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:

         if model_list:
@@ -57,7 +59,7 @@ class Router:
                 self.deployment_latency_map[m["litellm_params"]["model"]] = 0

         self.num_retries = num_retries
+        self.set_verbose = set_verbose
         self.chat = litellm.Chat(params=default_litellm_params)
         self.default_litellm_params = default_litellm_params
@@ -69,6 +71,7 @@ class Router:
             self._start_health_check_thread()

         ### CACHING ###
+        redis_cache = None
         if redis_host is not None and redis_port is not None and redis_password is not None:
             cache_config = {
                 'type': 'redis',
@@ -76,6 +79,7 @@
                 'port': redis_port,
                 'password': redis_password
             }
+            redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password)
         else: # use an in-memory cache
             cache_config = {
                 "type": "local"
@@ -83,7 +87,7 @@
         if cache_responses:
             litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
             self.cache_responses = cache_responses
-        self.cache = litellm.Cache(**cache_config) # use Redis for tracking load balancing
+        self.cache = DualCache(redis_cache=redis_cache) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
         ## USAGE TRACKING ##
         if isinstance(litellm.success_callback, list):
             litellm.success_callback.append(self.deployment_callback)
@@ -155,6 +159,10 @@ class Router:
     def get_model_names(self):
         return self.model_names

+    def print_verbose(self, print_statement):
+        if self.set_verbose:
+            print(f"LiteLLM.Router: {print_statement}") # noqa
+
     def get_available_deployment(self,
                                  model: str,
                                  messages: Optional[List[Dict[str, str]]] = None,
@@ -166,19 +174,18 @@
         ### get all deployments
         ### filter out the deployments currently cooling down
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
-        current_time = time.time()
-        iter = 0
         deployments_to_remove = []
         cooldown_deployments = self._get_cooldown_deployments()
+        self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
         ### FIND UNHEALTHY DEPLOYMENTS
         for deployment in healthy_deployments:
             deployment_name = deployment["litellm_params"]["model"]
             if deployment_name in cooldown_deployments:
                 deployments_to_remove.append(deployment)
-            iter += 1
         ### FILTER OUT UNHEALTHY DEPLOYMENTS
         for deployment in deployments_to_remove:
             healthy_deployments.remove(deployment)
+        self.print_verbose(f"healthy deployments: {healthy_deployments}")
         if litellm.model_alias_map and model in litellm.model_alias_map:
             model = litellm.model_alias_map[
                 model
@@ -245,42 +252,56 @@
     def function_with_retries(self, *args, **kwargs):
         # we'll backoff exponentially with each retry
         backoff_factor = 1
-        original_exception = kwargs.pop("original_exception")
         original_function = kwargs.pop("original_function")
-        for current_attempt in range(self.num_retries):
-            self.num_retries -= 1 # decrement the number of retries
+        num_retries = kwargs.pop("num_retries")
+        for current_attempt in range(num_retries):
+            self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
             try:
                 # if the function call is successful, no exception will be raised and we'll break out of the loop
                 response = original_function(*args, **kwargs)
                 return response
             except openai.RateLimitError as e:
-                # on RateLimitError we'll wait for an exponential time before trying again
-                time.sleep(backoff_factor)
-                # increase backoff factor for next run
-                backoff_factor *= 2
-            except openai.APIError as e:
-                # on APIError we immediately retry without any wait, change this if necessary
-                pass
+                if num_retries > 0:
+                    # on RateLimitError we'll wait for an exponential time before trying again
+                    time.sleep(backoff_factor)
+                    # increase backoff factor for next run
+                    backoff_factor *= 2
+                else:
+                    raise e
             except Exception as e:
-                # for any other exception types, don't retry
-                raise e
+                # for any other exception types, immediately retry
+                if num_retries > 0:
+                    pass
+                else:
+                    raise e
+            num_retries -= 1 # decrement the number of retries

     ### COMPLETION + EMBEDDING FUNCTIONS
     def completion(self,
                    model: str,
                    messages: List[Dict[str, str]],
-                   is_retry: Optional[bool] = False,
-                   is_fallback: Optional[bool] = False,
                    **kwargs):
         """
         Example usage:
         response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}]
         """
+        kwargs["model"] = model
+        kwargs["messages"] = messages
+        kwargs["original_function"] = self._completion
+        kwargs["num_retries"] = self.num_retries
+        return self.function_with_retries(**kwargs)
+
+    def _completion(
+            self,
+            model: str,
+            messages: List[Dict[str, str]],
+            **kwargs):
         try:
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(model=model, messages=messages)
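To see how the refactor changes the call flow, here is a rough caller-side sketch (not from the diff; the model list and API key are placeholders): completion() now packs the request and the retry budget into kwargs and delegates to function_with_retries, which drives the internal _completion.

    from litellm import Router

    model_list = [{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},  # placeholder key
    }]
    router = Router(model_list=model_list, num_retries=2)

    # Internally this becomes roughly:
    #   function_with_retries(model=..., messages=..., original_function=router._completion, num_retries=2)
    # On openai.RateLimitError the wrapper sleeps 1s, then 2s, ... between attempts;
    # other exceptions are retried immediately while retries remain.
    response = router.completion(model="gpt-3.5-turbo",
                                 messages=[{"role": "user", "content": "hello"}])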
@@ -288,18 +309,11 @@
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
                     data[k] = v
+            self.print_verbose(f"completion model: {data['model']}")
             return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
         except Exception as e:
-            if self.num_retries > 0:
-                kwargs["model"] = model
-                kwargs["messages"] = messages
-                kwargs["original_exception"] = e
-                kwargs["original_function"] = self.completion
-                return self.function_with_retries(**kwargs)
-            else:
-                raise e
+            raise e

     async def acompletion(self,
                           model: str,
                           messages: List[Dict[str, str]],
@@ -427,8 +441,9 @@
         current_minute = datetime.now().strftime("%H-%M")
         # get the current cooldown list for that minute
         cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
-        cached_value = self.cache.get_cache(cache_key=cooldown_key)
+        cached_value = self.cache.get_cache(key=cooldown_key)
+        self.print_verbose(f"adding {deployment} to cooldown models")
         # update value
         try:
             if deployment in cached_value:
@@ -436,12 +451,11 @@
             else:
                 cached_value = cached_value + [deployment]
             # save updated value
-            self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
+            self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
         except:
             cached_value = [deployment]
             # save updated value
-            self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
+            self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)

     def _get_cooldown_deployments(self):
         """
@@ -454,8 +468,9 @@
         # ----------------------
         # Return cooldown models
         # ----------------------
-        cooldown_models = self.cache.get_cache(cache_key=cooldown_key) or []
+        cooldown_models = self.cache.get_cache(key=cooldown_key) or []
+        self.print_verbose(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models

     def get_usage_based_available_deployment(self,
@@ -522,21 +537,21 @@
         # ------------
         # Return usage
        # ------------
-        tpm = self.cache.get_cache(cache_key=tpm_key) or 0
-        rpm = self.cache.get_cache(cache_key=rpm_key) or 0
+        tpm = self.cache.get_cache(key=tpm_key) or 0
+        rpm = self.cache.get_cache(key=rpm_key) or 0

         return int(tpm), int(rpm)

     def increment(self, key: str, increment_value: int):
         # get value
-        cached_value = self.cache.get_cache(cache_key=key)
+        cached_value = self.cache.get_cache(key=key)
         # update value
         try:
             cached_value = cached_value + increment_value
         except:
             cached_value = increment_value
         # save updated value
-        self.cache.add_cache(result=cached_value, cache_key=key, ttl=self.default_cache_time_seconds)
+        self.cache.set_cache(value=cached_value, key=key, ttl=self.default_cache_time_seconds)

     def _set_deployment_usage(
         self,
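For a concrete picture of what the router now keeps in its DualCache, here is a standalone sketch (not from the diff; the deployment name and key formats are invented for illustration). Cooldown lists and usage counters both follow a get -> modify -> set_cache pattern with short TTLs:

    from datetime import datetime
    from litellm.caching import DualCache

    cache = DualCache()  # in-memory only here; the Router also wires in a RedisCache when Redis is configured

    # Cooldown models are grouped per minute and expire after 60 seconds
    current_minute = datetime.now().strftime("%H-%M")
    cooldown_key = f"{current_minute}:cooldown_models"
    cooldown_models = cache.get_cache(key=cooldown_key) or []
    cooldown_models.append("azure/gpt-35-turbo-eu")        # hypothetical deployment name
    cache.set_cache(value=cooldown_models, key=cooldown_key, ttl=60)

    # TPM/RPM usage counters are read, incremented, and written back the same way
    tpm_key = "azure/gpt-35-turbo-eu:tpm"                  # hypothetical key format
    current_tpm = cache.get_cache(key=tpm_key) or 0
    cache.set_cache(value=current_tpm + 250, key=tpm_key, ttl=60)

    print(cache.get_cache(key=cooldown_key), cache.get_cache(key=tpm_key))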


@@ -18,7 +18,7 @@ load_dotenv()
 def test_multiple_deployments():
     import concurrent, time
-    litellm.set_verbose=True
+    litellm.set_verbose=False
     futures = {}
     model_list = [{ # list of model deployments
         "model_name": "gpt-3.5-turbo", # openai model name
@@ -58,6 +58,7 @@ def test_multiple_deployments():
                     redis_password=os.getenv("REDIS_PASSWORD"),
                     redis_port=int(os.getenv("REDIS_PORT")),
                     routing_strategy="simple-shuffle",
+                    set_verbose=False,
                     num_retries=1) # type: ignore
     # router = Router(model_list=model_list, redis_host=os.getenv("REDIS_HOST"), redis_password=os.getenv("REDIS_PASSWORD"), redis_port=int(os.getenv("REDIS_PORT"))) # type: ignore
     kwargs = {
@@ -81,12 +82,13 @@ Who among the mentioned figures from Ancient Greece contributed to the domain of
     }

     results = []
-    for _ in range(2):
-        print(f"starting!!!")
-        response = router.completion(**kwargs)
-        results.append(response)
+    try:
+        for _ in range(3):
+            response = router.completion(**kwargs)
+            results.append(response)
+    except Exception as e:
+        raise e

     # print(len(results))
     # with ThreadPoolExecutor(max_workers=100) as executor: