mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(router.py): completing redis support work for router
This commit is contained in:
parent
608a27d68b
commit
204218508d
3 changed files with 45 additions and 34 deletions
|
@ -2,25 +2,6 @@ from typing import Union, List, Dict, Optional
|
|||
from datetime import datetime
|
||||
import litellm
|
||||
|
||||
class Cache:
|
||||
"""
|
||||
Underlying dictionary for Router. This can either be a dictionary or a Redis Cache (if credentials are set).
|
||||
"""
|
||||
def __init__(self, cache_config: dict) -> None:
|
||||
self.cache_config = cache_config
|
||||
if cache_config["type"] == "redis":
|
||||
pass
|
||||
elif cache_config["type"] == "local":
|
||||
self.usage_dict: Dict = {}
|
||||
def get(self, key: str):
|
||||
return self.usage_dict.get(key, 0)
|
||||
|
||||
def increment(self, key: str, increment_value: int, expiry: int):
|
||||
if self.cache_config["type"] == "redis":
|
||||
pass
|
||||
elif self.cache_config["type"] == "local":
|
||||
self.usage_dict[key] = self.usage_dict.get(key, 0) + increment_value
|
||||
|
||||
class Router:
|
||||
"""
|
||||
Example usage:
|
||||
|
@ -52,12 +33,12 @@ class Router:
|
|||
'port': redis_port,
|
||||
'password': redis_password
|
||||
}
|
||||
else:
|
||||
else: # use an in-memory cache
|
||||
cache_config = {
|
||||
"type": "local"
|
||||
}
|
||||
self.cache = Cache(cache_config)
|
||||
litellm.cache = litellm.Cache(**cache_config)
|
||||
self.cache = litellm.Cache(cache_config) # use Redis for tracking load balancing
|
||||
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
|
||||
litellm.success_callback = [self.deployment_callback]
|
||||
|
||||
def completion(self,
|
||||
|
@ -182,10 +163,24 @@ class Router:
|
|||
# ------------
|
||||
# Return usage
|
||||
# ------------
|
||||
tpm = self.cache.get(tpm_key)
|
||||
rpm = self.cache.get(rpm_key)
|
||||
tpm = self.cache.get_cache(tpm_key)
|
||||
rpm = self.cache.get_cache(rpm_key)
|
||||
|
||||
if tpm is None:
|
||||
tpm = 0
|
||||
if rpm is None:
|
||||
rpm = 0
|
||||
|
||||
return int(tpm), int(rpm)
|
||||
|
||||
def increment(self, key: str, increment_value: int):
|
||||
# get value
|
||||
cached_value = self.cache.get_cache(key)
|
||||
# update value
|
||||
cached_value = cached_value + increment_value
|
||||
# save updated value
|
||||
self.cache.add_cache(result=cached_value, cache_key=key)
|
||||
|
||||
def _set_deployment_usage(
|
||||
self,
|
||||
model_name: str,
|
||||
|
@ -195,12 +190,11 @@ class Router:
|
|||
# Setup values
|
||||
# ------------
|
||||
current_minute = datetime.now().strftime("%H-%M")
|
||||
ttl = 120 # 2 minutes
|
||||
tpm_key = f'{model_name}:tpm:{current_minute}'
|
||||
rpm_key = f'{model_name}:rpm:{current_minute}'
|
||||
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
self.cache.increment(tpm_key, total_tokens, expiry=ttl)
|
||||
self.cache.increment(rpm_key, 1, expiry=ttl)
|
||||
self.cache.increment(tpm_key, total_tokens)
|
||||
self.cache.increment(rpm_key, 1)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue