fix(router.py): completing redis support work for router

This commit is contained in:
Krrish Dholakia 2023-10-18 12:12:50 -07:00
parent 608a27d68b
commit 204218508d
3 changed files with 45 additions and 34 deletions

View file

@ -2,25 +2,6 @@ from typing import Union, List, Dict, Optional
from datetime import datetime
import litellm
class Cache:
    """
    Underlying dictionary for Router. This can either be a dictionary or a Redis Cache (if credentials are set).

    ``cache_config`` must contain a ``"type"`` key. ``"local"`` keeps the
    counters in an in-process dict; ``"redis"`` is accepted but its branches
    are still placeholders, so nothing is written for it.
    """

    def __init__(self, cache_config: dict) -> None:
        """Store the config and create the in-memory counter store.

        Raises:
            KeyError: if ``cache_config`` has no ``"type"`` entry.
        """
        self.cache_config = cache_config
        # Always create the dict: get() reads it unconditionally, so leaving
        # it unset for the "redis" type made get() raise AttributeError.
        self.usage_dict: Dict = {}
        if cache_config["type"] == "redis":
            pass  # Redis backend not implemented in this class yet

    def get(self, key: str):
        """Return the counter stored under ``key``, or 0 if it was never set."""
        return self.usage_dict.get(key, 0)

    def increment(self, key: str, increment_value: int, expiry: int):
        """Add ``increment_value`` to the counter stored under ``key``.

        ``expiry`` is accepted for interface compatibility; the in-memory
        store does not expire entries.
        """
        if self.cache_config["type"] == "redis":
            pass  # placeholder: no Redis write is performed
        elif self.cache_config["type"] == "local":
            self.usage_dict[key] = self.usage_dict.get(key, 0) + increment_value
class Router:
"""
Example usage:
@ -52,12 +33,12 @@ class Router:
'port': redis_port,
'password': redis_password
}
else:
else: # use an in-memory cache
cache_config = {
"type": "local"
}
self.cache = Cache(cache_config)
litellm.cache = litellm.Cache(**cache_config)
self.cache = litellm.Cache(cache_config) # use Redis for tracking load balancing
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
litellm.success_callback = [self.deployment_callback]
def completion(self,
@ -182,10 +163,24 @@ class Router:
# ------------
# Return usage
# ------------
tpm = self.cache.get(tpm_key)
rpm = self.cache.get(rpm_key)
tpm = self.cache.get_cache(tpm_key)
rpm = self.cache.get_cache(rpm_key)
if tpm is None:
tpm = 0
if rpm is None:
rpm = 0
return int(tpm), int(rpm)
def increment(self, key: str, increment_value: int):
    """Add ``increment_value`` to the number cached under ``key``.

    Reads the current value from ``self.cache``, adds the increment, and
    writes the sum back under the same key.

    Args:
        key: cache key of the counter to update.
        increment_value: amount to add to the stored value.
    """
    # Pass the key by keyword, the same way add_cache receives it below.
    cached_value = self.cache.get_cache(cache_key=key)
    # A key that has not been written yet comes back as None (the usage
    # reader in this file guards for the same case); start counting from 0
    # instead of failing on `None + int`.
    if cached_value is None:
        cached_value = 0
    cached_value = cached_value + increment_value
    # save updated value
    self.cache.add_cache(result=cached_value, cache_key=key)
def _set_deployment_usage(
self,
model_name: str,
@ -195,12 +190,11 @@ class Router:
# Setup values
# ------------
current_minute = datetime.now().strftime("%H-%M")
ttl = 120 # 2 minutes
tpm_key = f'{model_name}:tpm:{current_minute}'
rpm_key = f'{model_name}:rpm:{current_minute}'
# ------------
# Update usage
# ------------
self.cache.increment(tpm_key, total_tokens, expiry=ttl)
self.cache.increment(rpm_key, 1, expiry=ttl)
self.cache.increment(tpm_key, total_tokens)
self.cache.increment(rpm_key, 1)