forked from phoenix/litellm-mirror
fix(router.py): fix caching for tracking cooldowns + usage
parent 94c1d71b2c
commit 61fc76a8c4
5 changed files with 148 additions and 75 deletions
@@ -89,6 +89,7 @@ const sidebars = {
        "routing",
        "rules",
        "set_keys",
        "budget_manager",
        "completion/token_usage",
        {
          type: 'category',
@@ -157,7 +158,6 @@ const sidebars = {
      label: 'Extras',
      items: [
        'extras/contributing',
        "budget_manager",
        "proxy_server",
        {
          type: "category",

@@ -23,6 +23,9 @@ def get_prompt(*args, **kwargs):
        return prompt
    return None

def print_verbose(print_statement):
    if litellm.set_verbose:
        print(print_statement) # noqa

class BaseCache:
    def set_cache(self, key, value, **kwargs):
@@ -32,6 +35,34 @@ class BaseCache:
        raise NotImplementedError


class InMemoryCache(BaseCache):
    def __init__(self):
        # if users don't provide one, use the default litellm cache
        self.cache_dict = {}
        self.ttl_dict = {}

    def set_cache(self, key, value, **kwargs):
        self.cache_dict[key] = value
        if "ttl" in kwargs:
            self.ttl_dict[key] = time.time() + kwargs["ttl"]

    def get_cache(self, key, **kwargs):
        if key in self.cache_dict:
            if key in self.ttl_dict:
                if time.time() > self.ttl_dict[key]:
                    self.cache_dict.pop(key, None)
                    return None
            original_cached_response = self.cache_dict[key]
            try:
                cached_response = json.loads(original_cached_response)
            except:
                cached_response = original_cached_response
            if isinstance(cached_response, dict):
                cached_response['cache'] = True # set cache-hit flag to True
            return cached_response
        return None


class RedisCache(BaseCache):
    def __init__(self, host, port, password):
        import redis
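As a side note on the TTL handling above, here is a minimal usage sketch (illustrative only, not part of this commit; it assumes litellm is installed so that litellm.caching.InMemoryCache is importable, as the router imports later in this diff suggest):

    # illustrative example, not part of this commit
    import time
    from litellm.caching import InMemoryCache

    cache = InMemoryCache()
    cache.set_cache("my-key", "my-value", ttl=1)  # entry expires ~1 second after being set
    print(cache.get_cache("my-key"))  # "my-value" while the TTL is live

    time.sleep(1.5)
    print(cache.get_cache("my-key"))  # None - the expired entry is popped on read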
@@ -65,7 +96,58 @@ class RedisCache(BaseCache):
            traceback.print_exc()
            logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

class DualCache(BaseCache):
    """
    This updates both Redis and an in-memory cache simultaneously.
    When data is updated or inserted, it is written to both the in-memory cache + Redis.
    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
    """
    def __init__(self, in_memory_cache: InMemoryCache = None, redis_cache: RedisCache = None) -> None:
        super().__init__()
        # If in_memory_cache is not provided, use the default InMemoryCache
        self.in_memory_cache = in_memory_cache or InMemoryCache()
        # If redis_cache is not provided, use the default RedisCache
        self.redis_cache = redis_cache

    def set_cache(self, key, value, **kwargs):
        # Update both Redis and in-memory cache
        try:
            print_verbose(f"set cache: key: {key}; value: {value}")
            if self.in_memory_cache is not None:
                self.in_memory_cache.set_cache(key, value, **kwargs)

            if self.redis_cache is not None:
                self.redis_cache.set_cache(key, value, **kwargs)
        except Exception as e:
            print_verbose(e)

    def get_cache(self, key, **kwargs):
        # Try to fetch from in-memory cache first
        try:
            print_verbose(f"get cache: cache key: {key}")
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)

                if in_memory_result is not None:
                    result = in_memory_result

            if self.redis_cache is not None:
                # If not found in in-memory cache, try fetching from Redis
                redis_result = self.redis_cache.get_cache(key, **kwargs)

                if redis_result is not None:
                    # Update in-memory cache with the value from Redis
                    self.in_memory_cache.set_cache(key, redis_result, **kwargs)

                    result = redis_result

            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception as e:
            traceback.print_exc()
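A minimal usage sketch of the write-through/read-through behavior described in the DualCache docstring above (illustrative only; running without a RedisCache and the cooldown-style key and deployment name are assumptions for the example):

    # illustrative example, not part of this commit
    from litellm.caching import DualCache, InMemoryCache

    # no redis_cache supplied, so only the in-memory layer is active;
    # with a RedisCache, set_cache would write to both and get_cache would backfill memory from Redis
    dual_cache = DualCache(in_memory_cache=InMemoryCache(), redis_cache=None)

    dual_cache.set_cache("09-15:cooldown_models", ["azure/gpt-35-turbo"], ttl=60)  # placeholder key + deployment name
    print(dual_cache.get_cache("09-15:cooldown_models"))  # ['azure/gpt-35-turbo'], served from the in-memory layer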

#### DEPRECATED ####
class HostedCache(BaseCache):
    def set_cache(self, key, value, **kwargs):
        if "ttl" in kwargs:
@@ -91,33 +173,7 @@ class HostedCache(BaseCache):
        return cached_response


class InMemoryCache(BaseCache):
    def __init__(self):
        # if users don't provide one, use the default litellm cache
        self.cache_dict = {}
        self.ttl_dict = {}

    def set_cache(self, key, value, **kwargs):
        self.cache_dict[key] = value
        if "ttl" in kwargs:
            self.ttl_dict[key] = time.time() + kwargs["ttl"]

    def get_cache(self, key, **kwargs):
        if key in self.cache_dict:
            if key in self.ttl_dict:
                if time.time() > self.ttl_dict[key]:
                    self.cache_dict.pop(key, None)
                    return None
            original_cached_response = self.cache_dict[key]
            try:
                cached_response = json.loads(original_cached_response)
            except:
                cached_response = original_cached_response
            cached_response['cache'] = True # set cache-hit flag to True
            return cached_response
        return None


#### LiteLLM.Completion Cache ####
class Cache:
    def __init__(
        self,

@@ -194,7 +194,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
            "role": "user",
            "content": "this is a test request, write a short poem"
        }
    ])
    ], max_tokens=256)
    click.echo(f'\nLiteLLM: response from proxy {response}')

    print("\n Making streaming request to proxy")

@@ -11,6 +11,7 @@ from datetime import datetime
from typing import Dict, List, Optional, Union, Literal
import random, threading, time
import litellm, openai
from litellm.caching import RedisCache, InMemoryCache, DualCache
import logging, asyncio
import inspect
from openai import AsyncOpenAI
@@ -46,6 +47,7 @@ class Router:
                 num_retries: int = 0,
                 timeout: float = 600,
                 default_litellm_params = {}, # default params for Router.chat.completion.create
                 set_verbose: bool = False,
                 routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:

        if model_list:
@@ -57,7 +59,7 @@ class Router:
                self.deployment_latency_map[m["litellm_params"]["model"]] = 0

        self.num_retries = num_retries

        self.set_verbose = set_verbose
        self.chat = litellm.Chat(params=default_litellm_params)

        self.default_litellm_params = default_litellm_params
@@ -69,6 +71,7 @@ class Router:
            self._start_health_check_thread()

        ### CACHING ###
        redis_cache = None
        if redis_host is not None and redis_port is not None and redis_password is not None:
            cache_config = {
                'type': 'redis',
@@ -76,6 +79,7 @@ class Router:
                'port': redis_port,
                'password': redis_password
            }
            redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password)
        else: # use an in-memory cache
            cache_config = {
                "type": "local"
@@ -83,7 +87,7 @@ class Router:
        if cache_responses:
            litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
            self.cache_responses = cache_responses
        self.cache = litellm.Cache(**cache_config) # use Redis for tracking load balancing
        self.cache = DualCache(redis_cache=redis_cache) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
        ## USAGE TRACKING ##
        if isinstance(litellm.success_callback, list):
            litellm.success_callback.append(self.deployment_callback)
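To illustrate the constructor wiring above, a hedged sketch of building a Router so that cooldown and usage tracking flow through the DualCache (illustrative only; the deployment name, credentials, and model_list shape are placeholders patterned after the test at the end of this diff):

    # illustrative example, not part of this commit
    import os
    from litellm import Router

    model_list = [
        {
            "model_name": "gpt-3.5-turbo",  # model group name callers use
            "litellm_params": {             # params forwarded to litellm.completion()
                "model": "azure/chatgpt-v-2",  # placeholder deployment name
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        },
    ]

    # with Redis credentials the Router builds a RedisCache and hands it to DualCache;
    # without them, the DualCache falls back to its in-memory layer only
    router = Router(
        model_list=model_list,
        redis_host=os.getenv("REDIS_HOST"),
        redis_password=os.getenv("REDIS_PASSWORD"),
        redis_port=int(os.getenv("REDIS_PORT", "6379")),
        routing_strategy="simple-shuffle",
        num_retries=1,
    )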
@@ -155,6 +159,10 @@ class Router:
    def get_model_names(self):
        return self.model_names

    def print_verbose(self, print_statement):
        if self.set_verbose:
            print(f"LiteLLM.Router: {print_statement}") # noqa

    def get_available_deployment(self,
                                 model: str,
                                 messages: Optional[List[Dict[str, str]]] = None,
@@ -166,19 +174,18 @@ class Router:
        ### get all deployments
        ### filter out the deployments currently cooling down
        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
        current_time = time.time()
        iter = 0
        deployments_to_remove = []
        cooldown_deployments = self._get_cooldown_deployments()
        self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
        ### FIND UNHEALTHY DEPLOYMENTS
        for deployment in healthy_deployments:
            deployment_name = deployment["litellm_params"]["model"]
            if deployment_name in cooldown_deployments:
                deployments_to_remove.append(deployment)
            iter += 1
        ### FILTER OUT UNHEALTHY DEPLOYMENTS
        for deployment in deployments_to_remove:
            healthy_deployments.remove(deployment)
        self.print_verbose(f"healthy deployments: {healthy_deployments}")
        if litellm.model_alias_map and model in litellm.model_alias_map:
            model = litellm.model_alias_map[
                model
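The cooldown filtering above boils down to a small membership check; a minimal sketch with hypothetical deployment names, shaped like the model_list entries used elsewhere in this diff:

    # illustrative example, not part of this commit
    healthy_deployments = [
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/chatgpt-v-2"}},  # hypothetical deployments
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
    ]
    cooldown_deployments = ["azure/chatgpt-v-2"]  # e.g. what _get_cooldown_deployments() returns

    healthy_deployments = [
        d for d in healthy_deployments
        if d["litellm_params"]["model"] not in cooldown_deployments
    ]
    print(healthy_deployments)  # only the deployment that is not cooling down remains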
@@ -245,42 +252,56 @@ class Router:
    def function_with_retries(self, *args, **kwargs):
        # we'll backoff exponentially with each retry
        backoff_factor = 1
        original_exception = kwargs.pop("original_exception")
        original_function = kwargs.pop("original_function")
        for current_attempt in range(self.num_retries):
            self.num_retries -= 1 # decrement the number of retries
        num_retries = kwargs.pop("num_retries")
        for current_attempt in range(num_retries):
            self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
            try:
                # if the function call is successful, no exception will be raised and we'll break out of the loop
                response = original_function(*args, **kwargs)
                return response

            except openai.RateLimitError as e:
                # on RateLimitError we'll wait for an exponential time before trying again
                time.sleep(backoff_factor)

                # increase backoff factor for next run
                backoff_factor *= 2

            except openai.APIError as e:
                # on APIError we immediately retry without any wait, change this if necessary
                pass
                if num_retries > 0:
                    # on RateLimitError we'll wait for an exponential time before trying again
                    time.sleep(backoff_factor)

                    # increase backoff factor for next run
                    backoff_factor *= 2
                else:
                    raise e

            except Exception as e:
                # for any other exception types, don't retry
                raise e
                # for any other exception types, immediately retry
                if num_retries > 0:
                    pass
                else:
                    raise e
            num_retries -= 1 # decrement the number of retries

    ### COMPLETION + EMBEDDING FUNCTIONS

    def completion(self,
                   model: str,
                   messages: List[Dict[str, str]],
                   is_retry: Optional[bool] = False,
                   is_fallback: Optional[bool] = False,
                   **kwargs):
        """
        Example usage:
        response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
        """

        kwargs["model"] = model
        kwargs["messages"] = messages
        kwargs["original_function"] = self._completion
        kwargs["num_retries"] = self.num_retries
        return self.function_with_retries(**kwargs)

    def _completion(
            self,
            model: str,
            messages: List[Dict[str, str]],
            **kwargs):

        try:
            # pick the one that is available (lowest TPM/RPM)
            deployment = self.get_available_deployment(model=model, messages=messages)
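To make the backoff arithmetic concrete, a standalone sketch of the same retry pattern (generic and illustrative, not the Router method itself; flaky_call and the retry count are invented for the example):

    # illustrative example, not part of this commit
    import time

    def retry_with_backoff(fn, num_retries=3):
        # mirrors the pattern above: sleep backoff_factor seconds on a retryable error,
        # then double it, so the waits are 1s, 2s, 4s, ...
        backoff_factor = 1
        for current_attempt in range(num_retries):
            try:
                return fn()
            except TimeoutError as e:  # stand-in for openai.RateLimitError
                if current_attempt == num_retries - 1:
                    raise e
                time.sleep(backoff_factor)
                backoff_factor *= 2

    def flaky_call(state={"calls": 0}):
        # fails twice, then succeeds (mutable default used as a simple call counter)
        state["calls"] += 1
        if state["calls"] < 3:
            raise TimeoutError("simulated rate limit")
        return "ok"

    print(retry_with_backoff(flaky_call))  # succeeds on the 3rd attempt after ~1s + 2s of backoff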
@@ -288,18 +309,11 @@ class Router:
            for k, v in self.default_litellm_params.items():
                if k not in data: # prioritize model-specific params > default router params
                    data[k] = v

            self.print_verbose(f"completion model: {data['model']}")
            return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
        except Exception as e:
            if self.num_retries > 0:
                kwargs["model"] = model
                kwargs["messages"] = messages
                kwargs["original_exception"] = e
                kwargs["original_function"] = self.completion
                return self.function_with_retries(**kwargs)
            else:
                raise e


            raise e
    async def acompletion(self,
                          model: str,
                          messages: List[Dict[str, str]],
@@ -427,8 +441,9 @@ class Router:
        current_minute = datetime.now().strftime("%H-%M")
        # get the current cooldown list for that minute
        cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
        cached_value = self.cache.get_cache(cache_key=cooldown_key)
        cached_value = self.cache.get_cache(key=cooldown_key)

        self.print_verbose(f"adding {deployment} to cooldown models")
        # update value
        try:
            if deployment in cached_value:
@@ -436,12 +451,11 @@ class Router:
            else:
                cached_value = cached_value + [deployment]
                # save updated value
                self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
                self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)
        except:
            cached_value = [deployment]

            # save updated value
            self.cache.add_cache(result=cached_value, cache_key=cooldown_key, ttl=60)
            self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=60)

    def _get_cooldown_deployments(self):
        """
@@ -454,8 +468,9 @@ class Router:
        # ----------------------
        # Return cooldown models
        # ----------------------
        cooldown_models = self.cache.get_cache(cache_key=cooldown_key) or []
        cooldown_models = self.cache.get_cache(key=cooldown_key) or []

        self.print_verbose(f"retrieve cooldown models: {cooldown_models}")
        return cooldown_models

    def get_usage_based_available_deployment(self,
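Together, the two hunks above read and write the cooldown list through the DualCache under a per-minute key; a minimal sketch of that key scheme outside the Router (illustrative only; in-memory only here, and the deployment name is a placeholder):

    # illustrative example, not part of this commit
    from datetime import datetime
    from litellm.caching import DualCache

    cache = DualCache()  # in-memory only in this sketch; the Router passes a RedisCache when configured

    deployment = "azure/chatgpt-v-2"  # placeholder deployment name
    current_minute = datetime.now().strftime("%H-%M")
    cooldown_key = f"{current_minute}:cooldown_models"  # one key per minute keeps the number of Redis calls low

    # add a deployment to this minute's cooldown bucket (60s TTL, matching the Router)
    cooldown_models = cache.get_cache(key=cooldown_key) or []
    if deployment not in cooldown_models:
        cache.set_cache(value=cooldown_models + [deployment], key=cooldown_key, ttl=60)

    # later reads see the cooled-down deployment until the minute bucket expires
    print(cache.get_cache(key=cooldown_key))  # ['azure/chatgpt-v-2']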
@@ -522,21 +537,21 @@ class Router:
        # ------------
        # Return usage
        # ------------
        tpm = self.cache.get_cache(cache_key=tpm_key) or 0
        rpm = self.cache.get_cache(cache_key=rpm_key) or 0
        tpm = self.cache.get_cache(key=tpm_key) or 0
        rpm = self.cache.get_cache(key=rpm_key) or 0

        return int(tpm), int(rpm)

    def increment(self, key: str, increment_value: int):
        # get value
        cached_value = self.cache.get_cache(cache_key=key)
        cached_value = self.cache.get_cache(key=key)
        # update value
        try:
            cached_value = cached_value + increment_value
        except:
            cached_value = increment_value
        # save updated value
        self.cache.add_cache(result=cached_value, cache_key=key, ttl=self.default_cache_time_seconds)
        self.cache.set_cache(value=cached_value, key=key, ttl=self.default_cache_time_seconds)

    def _set_deployment_usage(
            self,

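The usage counters follow the same read-modify-write pattern against the DualCache; a small sketch with a placeholder key and a fixed 60-second TTL standing in for default_cache_time_seconds (illustrative only):

    # illustrative example, not part of this commit
    from litellm.caching import DualCache

    cache = DualCache()  # in-memory only here; the Router wires in Redis when credentials are provided

    def increment(cache, key, increment_value, ttl=60):
        # same read-modify-write shape as Router.increment above
        cached_value = cache.get_cache(key=key)
        try:
            cached_value = cached_value + increment_value
        except TypeError:  # first write: nothing cached yet
            cached_value = increment_value
        cache.set_cache(value=cached_value, key=key, ttl=ttl)

    tpm_key = "azure/chatgpt-v-2:tpm:09-15"  # placeholder key name, invented for this example
    increment(cache, tpm_key, 250)
    increment(cache, tpm_key, 175)
    print(cache.get_cache(key=tpm_key) or 0)  # 425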
@@ -18,7 +18,7 @@ load_dotenv()

def test_multiple_deployments():
    import concurrent, time
    litellm.set_verbose=True
    litellm.set_verbose=False
    futures = {}
    model_list = [{ # list of model deployments
        "model_name": "gpt-3.5-turbo", # openai model name
@@ -58,6 +58,7 @@ def test_multiple_deployments():
                    redis_password=os.getenv("REDIS_PASSWORD"),
                    redis_port=int(os.getenv("REDIS_PORT")),
                    routing_strategy="simple-shuffle",
                    set_verbose=False,
                    num_retries=1) # type: ignore
    # router = Router(model_list=model_list, redis_host=os.getenv("REDIS_HOST"), redis_password=os.getenv("REDIS_PASSWORD"), redis_port=int(os.getenv("REDIS_PORT"))) # type: ignore
    kwargs = {
@@ -81,12 +82,13 @@ Who among the mentioned figures from Ancient Greece contributed to the domain of
    }

    results = []

    for _ in range(2):
        print(f"starting!!!")
        response = router.completion(**kwargs)
        results.append(response)

    try:
        for _ in range(3):
            response = router.completion(**kwargs)
            results.append(response)
    except Exception as e:
        raise e
    # print(len(results))
    # with ThreadPoolExecutor(max_workers=100) as executor:
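For completeness, a hedged sketch of the call pattern the test exercises (the kwargs shape is an assumption based on the Router.completion signature shown earlier in this diff; router is the instance built in the test above):

    # illustrative example, not part of this commit
    kwargs = {
        "model": "gpt-3.5-turbo",  # model group name from model_list
        "messages": [{"role": "user", "content": "this is a test request, write a short poem"}],
    }

    results = []
    for _ in range(3):
        response = router.completion(**kwargs)  # retried via function_with_retries on failure
        results.append(response)
    print(len(results))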