Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

fix(router.py): fix caching for tracking cooldowns + usage

commit 61fc76a8c4 (parent 94c1d71b2c)
5 changed files with 148 additions and 75 deletions
litellm/router.py

@@ -11,6 +11,7 @@ from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
 import litellm, openai
+from litellm.caching import RedisCache, InMemoryCache, DualCache
 import logging, asyncio
 import inspect
 from openai import AsyncOpenAI
@@ -46,6 +47,7 @@ class Router:
                  num_retries: int = 0,
                  timeout: float = 600,
                  default_litellm_params = {}, # default params for Router.chat.completion.create
                  set_verbose: bool = False,
                  routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
+
         if model_list:
@@ -57,7 +59,7 @@ class Router:
                 self.deployment_latency_map[m["litellm_params"]["model"]] = 0

         self.num_retries = num_retries
-
         self.set_verbose = set_verbose
         self.chat = litellm.Chat(params=default_litellm_params)
+
         self.default_litellm_params = default_litellm_params
@@ -69,6 +71,7 @@ class Router:
             self._start_health_check_thread()

         ### CACHING ###
+        redis_cache = None
         if redis_host is not None and redis_port is not None and redis_password is not None:
             cache_config = {
                     'type': 'redis',
@@ -76,6 +79,7 @@ class Router:
                     'port': redis_port,
                     'password': redis_password
             }
+            redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password)
         else: # use an in-memory cache
             cache_config = {
                 "type": "local"
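For orientation, here is a minimal sketch of constructing a Router against the configuration path shown above. The deployment entries and environment-variable names are illustrative placeholders, not taken from this commit; only constructor parameters visible in the surrounding diff are used.

import os
from litellm import Router

# Illustrative deployment list; model names and keys are placeholders.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",            # alias callers route against
        "litellm_params": {
            "model": "azure/gpt-35-turbo-eu",     # underlying deployment
            "api_key": os.environ.get("AZURE_API_KEY_EU"),
        },
    },
]

# With redis_host/port/password supplied, the Router builds a RedisCache and wraps it
# in the DualCache introduced by this commit; otherwise it falls back to in-memory only.
router = Router(
    model_list=model_list,
    redis_host=os.environ.get("REDIS_HOST"),
    redis_port=int(os.environ.get("REDIS_PORT", "6379")),
    redis_password=os.environ.get("REDIS_PASSWORD"),
    cache_responses=False,
    num_retries=2,
    set_verbose=True,
    routing_strategy="simple-shuffle",
)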
@@ -83,7 +87,7 @@ class Router:
         if cache_responses:
             litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
             self.cache_responses = cache_responses
-        self.cache = litellm.Cache(**cache_config) # use Redis for tracking load balancing
+        self.cache = DualCache(redis_cache=redis_cache) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
         ## USAGE TRACKING ##
         if isinstance(litellm.success_callback, list):
             litellm.success_callback.append(self.deployment_callback)
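The changed line above is the heart of the fix: cooldown and usage state now goes through a DualCache rather than a second litellm.Cache, so lookups hit a process-local in-memory layer first and, when Redis is configured, are mirrored to Redis so state can be shared. Below is a rough sketch of that pattern, not litellm's DualCache implementation; the redis-layer method names and the omitted in-memory TTL handling are simplifications.

class SimpleDualCache:
    # Write-through to both layers; read from memory first, then fall back to Redis.
    def __init__(self, redis_cache=None):
        self.in_memory = {}              # hot, process-local copy
        self.redis_cache = redis_cache   # shared layer, may be None

    def set_cache(self, key, value, ttl=None):
        self.in_memory[key] = value      # TTL on the in-memory layer omitted for brevity
        if self.redis_cache is not None:
            # assumed to expose a similar set_cache(key, value, ttl) interface
            self.redis_cache.set_cache(key=key, value=value, ttl=ttl)

    def get_cache(self, key):
        value = self.in_memory.get(key)
        if value is None and self.redis_cache is not None:
            value = self.redis_cache.get_cache(key=key)
        return value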
@@ -155,6 +159,10 @@ class Router:
     def get_model_names(self):
         return self.model_names

+    def print_verbose(self, print_statement):
+        if self.set_verbose:
+            print(f"LiteLLM.Router: {print_statement}") # noqa
+
     def get_available_deployment(self,
                                  model: str,
                                  messages: Optional[List[Dict[str, str]]] = None,
@@ -166,19 +174,18 @@ class Router:
         ### get all deployments
         ### filter out the deployments currently cooling down
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
-        current_time = time.time()
-        iter = 0
         deployments_to_remove = []
         cooldown_deployments = self._get_cooldown_deployments()
+        self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
         ### FIND UNHEALTHY DEPLOYMENTS
         for deployment in healthy_deployments:
             deployment_name = deployment["litellm_params"]["model"]
             if deployment_name in cooldown_deployments:
                 deployments_to_remove.append(deployment)
-            iter += 1
         ### FILTER OUT UNHEALTHY DEPLOYMENTS
         for deployment in deployments_to_remove:
             healthy_deployments.remove(deployment)
+        self.print_verbose(f"healthy deployments: {healthy_deployments}")
         if litellm.model_alias_map and model in litellm.model_alias_map:
             model = litellm.model_alias_map[
                 model
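The loop above removes any deployment whose underlying model is currently cooling down before a routing strategy picks from what remains. The same filtering step as a small standalone helper (the helper name is illustrative):

def filter_cooled_down(model_list, model, cooldown_deployments):
    # Keep only deployments for this model group that are not in the cooldown list.
    healthy = [m for m in model_list if m["model_name"] == model]
    return [
        m for m in healthy
        if m["litellm_params"]["model"] not in cooldown_deployments
    ]

# e.g. filter_cooled_down(router.model_list, "gpt-3.5-turbo", ["azure/gpt-35-turbo-eu"])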
@@ -245,42 +252,56 @@ class Router:
     def function_with_retries(self, *args, **kwargs):
         # we'll backoff exponentially with each retry
         backoff_factor = 1
-        original_exception = kwargs.pop("original_exception")
         original_function = kwargs.pop("original_function")
-        for current_attempt in range(self.num_retries):
-            self.num_retries -= 1 # decrement the number of retries
+        num_retries = kwargs.pop("num_retries")
+        for current_attempt in range(num_retries):
+            self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
             try:
                 # if the function call is successful, no exception will be raised and we'll break out of the loop
                 response = original_function(*args, **kwargs)
                 return response

             except openai.RateLimitError as e:
-                # on RateLimitError we'll wait for an exponential time before trying again
-                time.sleep(backoff_factor)
-
-                # increase backoff factor for next run
-                backoff_factor *= 2
-
-            except openai.APIError as e:
-                # on APIError we immediately retry without any wait, change this if necessary
-                pass
+                if num_retries > 0:
+                    # on RateLimitError we'll wait for an exponential time before trying again
+                    time.sleep(backoff_factor)
+
+                    # increase backoff factor for next run
+                    backoff_factor *= 2
+                else:
+                    raise e

             except Exception as e:
-                # for any other exception types, don't retry
-                raise e
+                # for any other exception types, immediately retry
+                if num_retries > 0:
+                    pass
+                else:
+                    raise e
+            num_retries -= 1 # decrement the number of retries

     ### COMPLETION + EMBEDDING FUNCTIONS

     def completion(self,
                    model: str,
                    messages: List[Dict[str, str]],
                    is_retry: Optional[bool] = False,
                    is_fallback: Optional[bool] = False,
                    **kwargs):
         """
         Example usage:
         response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}]
         """
+
+        kwargs["model"] = model
+        kwargs["messages"] = messages
+        kwargs["original_function"] = self._completion
+        kwargs["num_retries"] = self.num_retries
+        return self.function_with_retries(**kwargs)
+
+    def _completion(
+            self,
+            model: str,
+            messages: List[Dict[str, str]],
+            **kwargs):

         try:
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(model=model, messages=messages)
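The retry loop now pops num_retries from kwargs instead of mutating self.num_retries, waits with exponential backoff on rate limits, retries other exceptions immediately, and re-raises once retries run out, while completion() becomes a thin wrapper that sends the real work (_completion) through function_with_retries. A standalone sketch of the same backoff pattern, with an illustrative function name:

import time
import openai

def call_with_retries(fn, *args, num_retries=2, **kwargs):
    backoff_factor = 1
    for attempt in range(num_retries + 1):
        try:
            return fn(*args, **kwargs)   # success exits by returning
        except openai.RateLimitError:
            if attempt == num_retries:
                raise                    # out of retries
            time.sleep(backoff_factor)   # wait before retrying on rate limits
            backoff_factor *= 2
        except Exception:
            if attempt == num_retries:
                raise                    # out of retries; otherwise retry immediately

# e.g. call_with_retries(litellm.completion, model="gpt-3.5-turbo",
#                        messages=[{"role": "user", "content": "Hey"}], num_retries=2)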
@@ -288,18 +309,11 @@ class Router:
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
                     data[k] = v
+
+            self.print_verbose(f"completion model: {data['model']}")
             return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
         except Exception as e:
-            if self.num_retries > 0:
-                kwargs["model"] = model
-                kwargs["messages"] = messages
-                kwargs["original_exception"] = e
-                kwargs["original_function"] = self.completion
-                return self.function_with_retries(**kwargs)
-            else:
-                raise e
-
-
+            raise e
     async def acompletion(self,
                     model: str,
                     messages: List[Dict[str, str]],
@@ -427,8 +441,9 @@ class Router:
         current_minute = datetime.now().strftime("%H-%M")
         # get the current cooldown list for that minute
         cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
-        cached_value = self.cache.get_cache(cache_key=cooldown_key)
+        cached_value = self.cache.get_cache(key=cooldown_key)

+        self.print_verbose(f"adding {deployment} to cooldown models")
         # update value
         try:
             if deployment in cached_value:
@@ -436,12 +451,11 @@ class Router:
             else:
                 cached_value = cached_value + [deployment]
             # save updated value
-            self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
+            self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
         except:
             cached_value = [deployment]
-
             # save updated value
-            self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
+            self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)

     def _get_cooldown_deployments(self):
         """
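The cooldown bookkeeping above groups cooled-down deployments under a per-minute key ("HH-MM:cooldown_models") with a 60-second TTL, so repeated failures within the same minute share one cache entry and old entries simply age out. A standalone sketch of that bucketing, using a plain dict with expiry in place of the router's DualCache (helper names are illustrative):

import time
from datetime import datetime

_store = {}  # key -> (expires_at, value); stands in for the router's cache

def _set(key, value, ttl):
    _store[key] = (time.time() + ttl, value)

def _get(key):
    entry = _store.get(key)
    if entry is None or entry[0] < time.time():
        return None
    return entry[1]

def add_to_cooldown(deployment: str):
    cooldown_key = f"{datetime.now().strftime('%H-%M')}:cooldown_models"
    cached = _get(cooldown_key) or []
    if deployment not in cached:
        cached.append(deployment)
    _set(cooldown_key, cached, ttl=60)   # entries age out after a minute

def get_cooldown_deployments():
    cooldown_key = f"{datetime.now().strftime('%H-%M')}:cooldown_models"
    return _get(cooldown_key) or []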
@@ -454,8 +468,9 @@ class Router:
         # ----------------------
         # Return cooldown models
         # ----------------------
-        cooldown_models = self.cache.get_cache(cache_key=cooldown_key) or []
+        cooldown_models = self.cache.get_cache(key=cooldown_key) or []

+        self.print_verbose(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models

     def get_usage_based_available_deployment(self,
@@ -522,21 +537,21 @@ class Router:
         # ------------
         # Return usage
         # ------------
-        tpm = self.cache.get_cache(cache_key=tpm_key) or 0
-        rpm = self.cache.get_cache(cache_key=rpm_key) or 0
+        tpm = self.cache.get_cache(key=tpm_key) or 0
+        rpm = self.cache.get_cache(key=rpm_key) or 0

         return int(tpm), int(rpm)

     def increment(self, key: str, increment_value: int):
         # get value
-        cached_value = self.cache.get_cache(cache_key=key)
+        cached_value = self.cache.get_cache(key=key)
         # update value
         try:
             cached_value = cached_value + increment_value
         except:
             cached_value = increment_value
         # save updated value
-        self.cache.add_cache(result=cached_value, cache_key=key, ttl=self.default_cache_time_seconds)
+        self.cache.set_cache(value=cached_value, key=key, ttl=self.default_cache_time_seconds)

     def _set_deployment_usage(
         self,
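The usage helpers read and bump per-deployment TPM/RPM counters through the same cache, using a read-modify-write increment with a TTL. A standalone sketch of that increment; the key names in the usage comment are hypothetical, since the real tpm_key/rpm_key construction sits outside the hunks shown here:

def increment(cache, key, increment_value, ttl=60):
    current = cache.get_cache(key=key)
    try:
        current = current + increment_value
    except TypeError:                      # key missing (None) -> start fresh
        current = increment_value
    cache.set_cache(value=current, key=key, ttl=ttl)
    return current

# Hypothetical key names, for illustration only:
# increment(router.cache, f"{current_minute}:{deployment}:tpm", total_tokens)
# increment(router.cache, f"{current_minute}:{deployment}:rpm", 1)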