fix(router.py): fix caching for tracking cooldowns + usage

Krrish Dholakia 2023-11-23 11:13:24 -08:00
parent 94c1d71b2c
commit 61fc76a8c4
5 changed files with 148 additions and 75 deletions

litellm/router.py

@@ -11,6 +11,7 @@ from datetime import datetime
from typing import Dict, List, Optional, Union, Literal
import random, threading, time
import litellm, openai
from litellm.caching import RedisCache, InMemoryCache, DualCache
import logging, asyncio
import inspect
from openai import AsyncOpenAI
@@ -46,6 +47,7 @@ class Router:
num_retries: int = 0,
timeout: float = 600,
default_litellm_params = {}, # default params for Router.chat.completion.create
set_verbose: bool = False,
routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
if model_list:
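Not part of the diff: a minimal usage sketch of constructing the router with the new set_verbose flag. The deployment names and Redis credentials below are hypothetical.

from litellm import Router

model_list = [
    {"model_name": "gpt-3.5-turbo",
     "litellm_params": {"model": "azure/gpt-35-turbo-eu"}},  # hypothetical deployment
    {"model_name": "gpt-3.5-turbo",
     "litellm_params": {"model": "azure/gpt-35-turbo-us"}},  # hypothetical deployment
]

router = Router(
    model_list=model_list,
    redis_host="localhost",              # optional; enables the Redis layer of the tracking cache
    redis_port=6379,
    redis_password="hypothetical-password",
    num_retries=2,
    set_verbose=True,                    # new flag: prints "LiteLLM.Router: ..." debug lines
    routing_strategy="usage-based-routing",
)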
@@ -57,7 +59,7 @@
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
self.num_retries = num_retries
self.set_verbose = set_verbose
self.chat = litellm.Chat(params=default_litellm_params)
self.default_litellm_params = default_litellm_params
@@ -69,6 +71,7 @@
self._start_health_check_thread()
### CACHING ###
redis_cache = None
if redis_host is not None and redis_port is not None and redis_password is not None:
cache_config = {
'type': 'redis',
@@ -76,6 +79,7 @@
'port': redis_port,
'password': redis_password
}
redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password)
else: # use an in-memory cache
cache_config = {
"type": "local"
@@ -83,7 +87,7 @@
if cache_responses:
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
self.cache_responses = cache_responses
self.cache = litellm.Cache(**cache_config) # use Redis for tracking load balancing
self.cache = DualCache(redis_cache=redis_cache) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
## USAGE TRACKING ##
if isinstance(litellm.success_callback, list):
litellm.success_callback.append(self.deployment_callback)
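Not part of the diff: the router only relies on get_cache and set_cache, so a rough stand-in for the dual cache looks like the sketch below (assuming DualCache writes to both layers and reads the in-memory layer first, falling back to Redis; litellm.caching.DualCache is the real implementation).

class SimpleDualCache:
    """Illustrative stand-in for litellm.caching.DualCache, not the actual class."""
    def __init__(self, redis_cache=None):
        self.in_memory = {}              # always available, per-process
        self.redis_cache = redis_cache   # optional shared layer across router instances

    def set_cache(self, key, value, ttl=None):
        self.in_memory[key] = value      # TTL handling omitted for brevity
        if self.redis_cache is not None:
            self.redis_cache.set_cache(key=key, value=value, ttl=ttl)

    def get_cache(self, key):
        if key in self.in_memory:
            return self.in_memory[key]
        if self.redis_cache is not None:
            return self.redis_cache.get_cache(key=key)
        return None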
@@ -155,6 +159,10 @@ class Router:
def get_model_names(self):
return self.model_names
def print_verbose(self, print_statement):
if self.set_verbose:
print(f"LiteLLM.Router: {print_statement}") # noqa
def get_available_deployment(self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
@@ -166,19 +174,18 @@
### get all deployments
### filter out the deployments currently cooling down
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
current_time = time.time()
iter = 0
deployments_to_remove = []
cooldown_deployments = self._get_cooldown_deployments()
self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
### FIND UNHEALTHY DEPLOYMENTS
for deployment in healthy_deployments:
deployment_name = deployment["litellm_params"]["model"]
if deployment_name in cooldown_deployments:
deployments_to_remove.append(deployment)
iter += 1
### FILTER OUT UNHEALTHY DEPLOYMENTS
for deployment in deployments_to_remove:
healthy_deployments.remove(deployment)
self.print_verbose(f"healthy deployments: {healthy_deployments}")
if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[
model
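For illustration only (not part of the diff), the cooldown filtering above boils down to dropping any deployment whose litellm_params model appears in the cooldown list; the deployment names here are hypothetical.

healthy_deployments = [
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/gpt-35-turbo-eu"}},
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/gpt-35-turbo-us"}},
]
cooldown_deployments = ["azure/gpt-35-turbo-eu"]  # e.g. rate-limited during the current minute

healthy_deployments = [
    d for d in healthy_deployments
    if d["litellm_params"]["model"] not in cooldown_deployments
]
# only azure/gpt-35-turbo-us remains eligible for routing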
@@ -245,42 +252,56 @@ class Router:
def function_with_retries(self, *args, **kwargs):
# we'll backoff exponentially with each retry
backoff_factor = 1
original_exception = kwargs.pop("original_exception")
original_function = kwargs.pop("original_function")
for current_attempt in range(self.num_retries):
self.num_retries -= 1 # decrement the number of retries
num_retries = kwargs.pop("num_retries")
for current_attempt in range(num_retries):
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = original_function(*args, **kwargs)
return response
except openai.RateLimitError as e:
# on RateLimitError we'll wait for an exponential time before trying again
time.sleep(backoff_factor)
# increase backoff factor for next run
backoff_factor *= 2
except openai.APIError as e:
# on APIError we immediately retry without any wait, change this if necessary
pass
if num_retries > 0:
# on RateLimitError we'll wait for an exponential time before trying again
time.sleep(backoff_factor)
# increase backoff factor for next run
backoff_factor *= 2
else:
raise e
except Exception as e:
# for any other exception types, don't retry
raise e
# for any other exception types, immediately retry
if num_retries > 0:
pass
else:
raise e
num_retries -= 1 # decrement the number of retries
### COMPLETION + EMBEDDING FUNCTIONS
def completion(self,
model: str,
messages: List[Dict[str, str]],
is_retry: Optional[bool] = False,
is_fallback: Optional[bool] = False,
**kwargs):
"""
Example usage:
response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
"""
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_function"] = self._completion
kwargs["num_retries"] = self.num_retries
return self.function_with_retries(**kwargs)
def _completion(
self,
model: str,
messages: List[Dict[str, str]],
**kwargs):
try:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages)
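Not part of the diff: function_with_retries now receives num_retries through kwargs instead of decrementing self.num_retries, so concurrent requests no longer share retry state. A simplified sketch of the resulting pattern (not the exact router code):

import time
import openai

def call_with_retries(fn, num_retries, *args, **kwargs):
    backoff_factor = 1
    for attempt in range(num_retries + 1):
        try:
            return fn(*args, **kwargs)       # success: return immediately
        except openai.RateLimitError:
            if attempt == num_retries:
                raise                        # out of retries, surface the error
            time.sleep(backoff_factor)       # wait before retrying rate-limit errors
            backoff_factor *= 2              # exponential backoff
        except Exception:
            if attempt == num_retries:
                raise
            # any other error is retried immediately, mirroring the router's behavior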
@@ -288,18 +309,11 @@ class Router:
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
self.print_verbose(f"completion model: {data['model']}")
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
except Exception as e:
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_exception"] = e
kwargs["original_function"] = self.completion
return self.function_with_retries(**kwargs)
else:
raise e
raise e
async def acompletion(self,
model: str,
messages: List[Dict[str, str]],
@@ -427,8 +441,9 @@ class Router:
current_minute = datetime.now().strftime("%H-%M")
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
cached_value = self.cache.get_cache(cache_key=cooldown_key)
cached_value = self.cache.get_cache(key=cooldown_key)
self.print_verbose(f"adding {deployment} to cooldown models")
# update value
try:
if deployment in cached_value:
@ -436,12 +451,11 @@ class Router:
else:
cached_value = cached_value + [deployment]
# save updated value
self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
except:
cached_value = [deployment]
# save updated value
self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
def _get_cooldown_deployments(self):
"""
@@ -454,8 +468,9 @@ class Router:
# ----------------------
# Return cooldown models
# ----------------------
cooldown_models = self.cache.get_cache(cache_key=cooldown_key) or []
cooldown_models = self.cache.get_cache(key=cooldown_key) or []
self.print_verbose(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def get_usage_based_available_deployment(self,
@@ -522,21 +537,21 @@
# ------------
# Return usage
# ------------
tpm = self.cache.get_cache(cache_key=tpm_key) or 0
rpm = self.cache.get_cache(cache_key=rpm_key) or 0
tpm = self.cache.get_cache(key=tpm_key) or 0
rpm = self.cache.get_cache(key=rpm_key) or 0
return int(tpm), int(rpm)
def increment(self, key: str, increment_value: int):
# get value
cached_value = self.cache.get_cache(cache_key=key)
cached_value = self.cache.get_cache(key=key)
# update value
try:
cached_value = cached_value + increment_value
except:
cached_value = increment_value
# save updated value
self.cache.add_cache(result=cached_value, cache_key=key, ttl=self.default_cache_time_seconds)
self.cache.set_cache(value=cached_value, key=key, ttl=self.default_cache_time_seconds)
def _set_deployment_usage(
self,
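Not part of the diff: usage tracking follows the same read-modify-write pattern through the dual cache. A sketch below; the tpm/rpm key layout is hypothetical, since the real key construction sits above the displayed hunk.

def increment(cache, key, increment_value, ttl=60):
    current = cache.get_cache(key=key) or 0                   # a missing key counts as zero
    cache.set_cache(key=key, value=current + increment_value, ttl=ttl)

# hypothetical per-minute usage keys for one deployment
increment(cache, key="11-13:azure/gpt-35-turbo-eu:tpm", increment_value=42)  # tokens consumed
increment(cache, key="11-13:azure/gpt-35-turbo-eu:rpm", increment_value=1)   # one request served

tpm = cache.get_cache(key="11-13:azure/gpt-35-turbo-eu:tpm") or 0
rpm = cache.get_cache(key="11-13:azure/gpt-35-turbo-eu:rpm") or 0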