fix(router.py): fix caching for tracking cooldowns + usage

Krrish Dholakia 2023-11-23 11:13:24 -08:00
parent 94c1d71b2c
commit 61fc76a8c4
5 changed files with 148 additions and 75 deletions

View file

@@ -89,6 +89,7 @@ const sidebars = {
"routing",
"rules",
"set_keys",
"budget_manager",
"completion/token_usage",
{
type: 'category',
@@ -157,7 +158,6 @@ const sidebars = {
label: 'Extras',
items: [
'extras/contributing',
"budget_manager",
"proxy_server",
{
type: "category",

View file

@@ -23,6 +23,9 @@ def get_prompt(*args, **kwargs):
return prompt
return None
def print_verbose(print_statement):
if litellm.set_verbose:
print(print_statement) # noqa
class BaseCache:
def set_cache(self, key, value, **kwargs):
@@ -32,6 +35,34 @@ class BaseCache:
raise NotImplementedError
class InMemoryCache(BaseCache):
def __init__(self):
# if users don't provide one, use the default litellm cache
self.cache_dict = {}
self.ttl_dict = {}
def set_cache(self, key, value, **kwargs):
self.cache_dict[key] = value
if "ttl" in kwargs:
self.ttl_dict[key] = time.time() + kwargs["ttl"]
def get_cache(self, key, **kwargs):
if key in self.cache_dict:
if key in self.ttl_dict:
if time.time() > self.ttl_dict[key]:
self.cache_dict.pop(key, None)
return None
original_cached_response = self.cache_dict[key]
try:
cached_response = json.loads(original_cached_response)
except:
cached_response = original_cached_response
if isinstance(cached_response, dict):
cached_response['cache'] = True # set cache-hit flag to True
return cached_response
return None
class RedisCache(BaseCache):
def __init__(self, host, port, password):
import redis
@@ -65,7 +96,58 @@ class RedisCache(BaseCache):
traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
class DualCache(BaseCache):
"""
This updates both Redis and an in-memory cache simultaneously.
When data is updated or inserted, it is written to both the in-memory cache + Redis.
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
"""
def __init__(self, in_memory_cache: InMemoryCache = None, redis_cache: RedisCache = None) -> None:
super().__init__()
# If in_memory_cache is not provided, use the default InMemoryCache
self.in_memory_cache = in_memory_cache or InMemoryCache()
# If redis_cache is not provided, use the default RedisCache
self.redis_cache = redis_cache
def set_cache(self, key, value, **kwargs):
# Update both Redis and in-memory cache
try:
print_verbose(f"set cache: key: {key}; value: {value}")
if self.in_memory_cache is not None:
self.in_memory_cache.set_cache(key, value, **kwargs)
if self.redis_cache is not None:
self.redis_cache.set_cache(key, value, **kwargs)
except Exception as e:
print_verbose(e)
def get_cache(self, key, **kwargs):
# Try to fetch from in-memory cache first
try:
print_verbose(f"get cache: cache key: {key}")
result = None
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
if in_memory_result is not None:
result = in_memory_result
if self.redis_cache is not None:
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.get_cache(key, **kwargs)
if redis_result is not None:
# Update in-memory cache with the value from Redis
self.in_memory_cache.set_cache(key, redis_result, **kwargs)
result = redis_result
print_verbose(f"get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
#### DEPRECATED ####
class HostedCache(BaseCache):
def set_cache(self, key, value, **kwargs):
if "ttl" in kwargs:
@@ -91,33 +173,7 @@ class HostedCache(BaseCache):
return cached_response
class InMemoryCache(BaseCache):
def __init__(self):
# if users don't provide one, use the default litellm cache
self.cache_dict = {}
self.ttl_dict = {}
def set_cache(self, key, value, **kwargs):
self.cache_dict[key] = value
if "ttl" in kwargs:
self.ttl_dict[key] = time.time() + kwargs["ttl"]
def get_cache(self, key, **kwargs):
if key in self.cache_dict:
if key in self.ttl_dict:
if time.time() > self.ttl_dict[key]:
self.cache_dict.pop(key, None)
return None
original_cached_response = self.cache_dict[key]
try:
cached_response = json.loads(original_cached_response)
except:
cached_response = original_cached_response
cached_response['cache'] = True # set cache-hit flag to True
return cached_response
return None
#### LiteLLM.Completion Cache ####
class Cache:
def __init__(
self,

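For reference, a minimal usage sketch of the new DualCache above, run here with only the in-memory layer; the key and value are illustrative placeholders, and passing a RedisCache instance as redis_cache would mirror writes to Redis as well:

from litellm.caching import DualCache, InMemoryCache

# no redis_cache passed -> only the in-memory layer is active,
# which is also what Router falls back to when no Redis credentials are supplied
cache = DualCache(in_memory_cache=InMemoryCache(), redis_cache=None)

# set_cache writes to every configured layer; the in-memory layer honors the ttl
cache.set_cache(key="17-05:cooldown_models", value=["azure/gpt-35-turbo"], ttl=60)

# get_cache checks the in-memory layer first and only falls back to Redis
# (back-filling the in-memory layer) when the key is missing locally
print(cache.get_cache(key="17-05:cooldown_models"))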
View file

@@ -194,7 +194,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
"role": "user",
"content": "this is a test request, write a short poem"
}
])
], max_tokens=256)
click.echo(f'\nLiteLLM: response from proxy {response}')
print("\n Making streaming request to proxy")

View file

@@ -11,6 +11,7 @@ from datetime import datetime
from typing import Dict, List, Optional, Union, Literal
import random, threading, time
import litellm, openai
from litellm.caching import RedisCache, InMemoryCache, DualCache
import logging, asyncio
import inspect
from openai import AsyncOpenAI
@@ -46,6 +47,7 @@ class Router:
num_retries: int = 0,
timeout: float = 600,
default_litellm_params = {}, # default params for Router.chat.completion.create
set_verbose: bool = False,
routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
if model_list:
@@ -57,7 +59,7 @@
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
self.num_retries = num_retries
self.set_verbose = set_verbose
self.chat = litellm.Chat(params=default_litellm_params)
self.default_litellm_params = default_litellm_params
@@ -69,6 +71,7 @@
self._start_health_check_thread()
### CACHING ###
redis_cache = None
if redis_host is not None and redis_port is not None and redis_password is not None:
cache_config = {
'type': 'redis',
@@ -76,6 +79,7 @@
'port': redis_port,
'password': redis_password
}
redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password)
else: # use an in-memory cache
cache_config = {
"type": "local"
@@ -83,7 +87,7 @@
if cache_responses:
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
self.cache_responses = cache_responses
self.cache = litellm.Cache(**cache_config) # use Redis for tracking load balancing
self.cache = DualCache(redis_cache=redis_cache) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
## USAGE TRACKING ##
if isinstance(litellm.success_callback, list):
litellm.success_callback.append(self.deployment_callback)
@@ -155,6 +159,10 @@ class Router:
def get_model_names(self):
return self.model_names
def print_verbose(self, print_statement):
if self.set_verbose:
print(f"LiteLLM.Router: {print_statement}") # noqa
def get_available_deployment(self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
@@ -166,19 +174,18 @@
### get all deployments
### filter out the deployments currently cooling down
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
current_time = time.time()
iter = 0
deployments_to_remove = []
cooldown_deployments = self._get_cooldown_deployments()
self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
### FIND UNHEALTHY DEPLOYMENTS
for deployment in healthy_deployments:
deployment_name = deployment["litellm_params"]["model"]
if deployment_name in cooldown_deployments:
deployments_to_remove.append(deployment)
iter += 1
### FILTER OUT UNHEALTHY DEPLOYMENTS
for deployment in deployments_to_remove:
healthy_deployments.remove(deployment)
self.print_verbose(f"healthy deployments: {healthy_deployments}")
if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[
model
@@ -245,42 +252,56 @@
def function_with_retries(self, *args, **kwargs):
# we'll backoff exponentially with each retry
backoff_factor = 1
original_exception = kwargs.pop("original_exception")
original_function = kwargs.pop("original_function")
for current_attempt in range(self.num_retries):
self.num_retries -= 1 # decrement the number of retries
num_retries = kwargs.pop("num_retries")
for current_attempt in range(num_retries):
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = original_function(*args, **kwargs)
return response
except openai.RateLimitError as e:
if num_retries > 0:
# on RateLimitError we'll wait for an exponential time before trying again
time.sleep(backoff_factor)
# increase backoff factor for next run
backoff_factor *= 2
except openai.APIError as e:
# on APIError we immediately retry without any wait, change this if necessary
pass
else:
raise e
except Exception as e:
# for any other exception types, don't retry
# for any other exception types, immediately retry
if num_retries > 0:
pass
else:
raise e
num_retries -= 1 # decrement the number of retries
### COMPLETION + EMBEDDING FUNCTIONS
def completion(self,
model: str,
messages: List[Dict[str, str]],
is_retry: Optional[bool] = False,
is_fallback: Optional[bool] = False,
**kwargs):
"""
Example usage:
response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
"""
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_function"] = self._completion
kwargs["num_retries"] = self.num_retries
return self.function_with_retries(**kwargs)
def _completion(
self,
model: str,
messages: List[Dict[str, str]],
**kwargs):
try:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages)
@@ -288,18 +309,11 @@
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
self.print_verbose(f"completion model: {data['model']}")
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
except Exception as e:
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_exception"] = e
kwargs["original_function"] = self.completion
return self.function_with_retries(**kwargs)
else:
raise e
async def acompletion(self,
model: str,
messages: List[Dict[str, str]],
@@ -427,8 +441,9 @@
current_minute = datetime.now().strftime("%H-%M")
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
cached_value = self.cache.get_cache(cache_key=cooldown_key)
cached_value = self.cache.get_cache(key=cooldown_key)
self.print_verbose(f"adding {deployment} to cooldown models")
# update value
try:
if deployment in cached_value:
@@ -436,12 +451,11 @@
else:
cached_value = cached_value + [deployment]
# save updated value
self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
except:
cached_value = [deployment]
# save updated value
self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
def _get_cooldown_deployments(self):
"""
@@ -454,8 +468,9 @@
# ----------------------
# Return cooldown models
# ----------------------
cooldown_models = self.cache.get_cache(cache_key=cooldown_key) or []
cooldown_models = self.cache.get_cache(key=cooldown_key) or []
self.print_verbose(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def get_usage_based_available_deployment(self,
@@ -522,21 +537,21 @@
# ------------
# Return usage
# ------------
tpm = self.cache.get_cache(cache_key=tpm_key) or 0
rpm = self.cache.get_cache(cache_key=rpm_key) or 0
tpm = self.cache.get_cache(key=tpm_key) or 0
rpm = self.cache.get_cache(key=rpm_key) or 0
return int(tpm), int(rpm)
def increment(self, key: str, increment_value: int):
# get value
cached_value = self.cache.get_cache(cache_key=key)
cached_value = self.cache.get_cache(key=key)
# update value
try:
cached_value = cached_value + increment_value
except:
cached_value = increment_value
# save updated value
self.cache.add_cache(result=cached_value, cache_key=key, ttl=self.default_cache_time_seconds)
self.cache.set_cache(value=cached_value, key=key, ttl=self.default_cache_time_seconds)
def _set_deployment_usage(
self,

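Taken together, the cooldown bookkeeping above is a small read-modify-write pattern against the dual cache. A standalone sketch of that pattern under the renamed get_cache/set_cache API; add_to_cooldown is a hypothetical helper rather than a Router method, the deployment name is a placeholder, and the real _set_cooldown_deployments additionally guards the list update with try/except:

from datetime import datetime
from litellm.caching import DualCache

cache = DualCache()  # in-memory only here; Router passes a RedisCache when Redis credentials are set

def add_to_cooldown(deployment: str):
    # hypothetical helper: group cooled-down models by minute so one key covers the whole window
    cooldown_key = f"{datetime.now().strftime('%H-%M')}:cooldown_models"
    cooldown_models = cache.get_cache(key=cooldown_key) or []
    if deployment not in cooldown_models:
        cooldown_models = cooldown_models + [deployment]
    # ttl=60 lets the entry expire on its own, which puts the deployment back into rotation
    cache.set_cache(value=cooldown_models, key=cooldown_key, ttl=60)

add_to_cooldown("azure/gpt-35-turbo")
print(cache.get_cache(key=f"{datetime.now().strftime('%H-%M')}:cooldown_models"))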
View file

@@ -18,7 +18,7 @@ load_dotenv()
def test_multiple_deployments():
import concurrent, time
litellm.set_verbose=True
litellm.set_verbose=False
futures = {}
model_list = [{ # list of model deployments
"model_name": "gpt-3.5-turbo", # openai model name
@@ -58,6 +58,7 @@ def test_multiple_deployments():
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=int(os.getenv("REDIS_PORT")),
routing_strategy="simple-shuffle",
set_verbose=False,
num_retries=1) # type: ignore
# router = Router(model_list=model_list, redis_host=os.getenv("REDIS_HOST"), redis_password=os.getenv("REDIS_PASSWORD"), redis_port=int(os.getenv("REDIS_PORT"))) # type: ignore
kwargs = {
@@ -82,11 +83,12 @@ Who among the mentioned figures from Ancient Greece contributed to the domain of
results = []
for _ in range(2):
print(f"starting!!!")
try:
for _ in range(3):
response = router.completion(**kwargs)
results.append(response)
except Exception as e:
raise e
# print(len(results))
# with ThreadPoolExecutor(max_workers=100) as executor: