fix(router.py): complete redis support work for router

This commit is contained in:
Krrish Dholakia 2023-10-18 12:12:50 -07:00
parent 608a27d68b
commit 204218508d
3 changed files with 45 additions and 34 deletions

View file

@ -158,6 +158,9 @@ class Cache():
The cached result if it exists, otherwise None.
"""
try: # never block execution
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
if cache_key is not None:
cached_result = self.cache.get_cache(cache_key)
@ -182,6 +185,9 @@ class Cache():
None
"""
try:
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
# print("adding to cache", cache_key, result)
# print(cache_key)

View file

@ -2,25 +2,6 @@ from typing import Union, List, Dict, Optional
from datetime import datetime
import litellm
class Cache:
    """
    Underlying dictionary for Router. This can either be a dictionary or a Redis Cache (if credentials are set).
    """
    def __init__(self, cache_config: dict) -> None:
        """
        Args:
            cache_config: must contain a "type" key, either "redis" or "local".
        """
        self.cache_config = cache_config
        # Always create the in-memory store. The original only created it for
        # the "local" type, so a "redis" config raised AttributeError on the
        # first get()/increment() call (the redis branches are still stubs and
        # fall back to this dict).
        self.usage_dict: Dict = {}
        if cache_config["type"] == "redis":
            # TODO: redis-backed storage is not implemented yet.
            pass

    def get(self, key: str):
        """Return the stored value for `key`, or 0 if the key is unknown."""
        return self.usage_dict.get(key, 0)

    def increment(self, key: str, increment_value: int, expiry: int):
        """
        Add `increment_value` to the counter stored under `key`.

        `expiry` (seconds) is accepted for the future redis backend
        (INCRBY + EXPIRE); the local dict store ignores it.
        """
        if self.cache_config["type"] == "redis":
            # TODO: redis INCRBY/EXPIRE not implemented yet.
            pass
        elif self.cache_config["type"] == "local":
            self.usage_dict[key] = self.usage_dict.get(key, 0) + increment_value
class Router:
"""
Example usage:
@ -52,12 +33,12 @@ class Router:
'port': redis_port,
'password': redis_password
}
else:
else: # use an in-memory cache
cache_config = {
"type": "local"
}
self.cache = Cache(cache_config)
litellm.cache = litellm.Cache(**cache_config)
self.cache = litellm.Cache(cache_config) # use Redis for tracking load balancing
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
litellm.success_callback = [self.deployment_callback]
def completion(self,
@ -182,10 +163,24 @@ class Router:
# ------------
# Return usage
# ------------
tpm = self.cache.get(tpm_key)
rpm = self.cache.get(rpm_key)
tpm = self.cache.get_cache(tpm_key)
rpm = self.cache.get_cache(rpm_key)
if tpm is None:
tpm = 0
if rpm is None:
rpm = 0
return int(tpm), int(rpm)
def increment(self, key: str, increment_value: int):
    """
    Increment the cached counter stored under `key` by `increment_value`.

    A missing key (get_cache returns None) is treated as 0, so the first
    increment for a key no longer raises `TypeError: unsupported operand
    type(s) for +: 'NoneType' and 'int'`.
    """
    # get value (None if the key has never been written)
    cached_value = self.cache.get_cache(key)
    if cached_value is None:
        cached_value = 0  # first write for this key
    # update value
    cached_value = cached_value + increment_value
    # save updated value
    self.cache.add_cache(result=cached_value, cache_key=key)
def _set_deployment_usage(
self,
model_name: str,
@ -195,12 +190,11 @@ class Router:
# Setup values
# ------------
current_minute = datetime.now().strftime("%H-%M")
ttl = 120 # 2 minutes
tpm_key = f'{model_name}:tpm:{current_minute}'
rpm_key = f'{model_name}:rpm:{current_minute}'
# ------------
# Update usage
# ------------
self.cache.increment(tpm_key, total_tokens, expiry=ttl)
self.cache.increment(rpm_key, 1, expiry=ttl)
self.cache.increment(tpm_key, total_tokens)
self.cache.increment(rpm_key, 1)

View file

@ -8,11 +8,12 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm import Router
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
load_dotenv()
model_list = [{
model_list = [{ # list of model deployments
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
@ -42,9 +43,19 @@ model_list = [{
"rpm": 9000
}]
# Route requests across the deployment list; Redis credentials let multiple
# workers share usage counters for load balancing.
router = Router(
    model_list=model_list,
    redis_host=os.getenv("REDIS_HOST"),
    redis_password=os.getenv("REDIS_PASSWORD"),
    redis_port=os.getenv("REDIS_PORT"),
)

# openai.ChatCompletion.create replacement: fire 20 concurrent completion
# calls to exercise the router's load balancing across deployments.
completions = []
with ThreadPoolExecutor(max_workers=100) as executor:
    kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    }
    for _ in range(20):
        future = executor.submit(router.completion, **kwargs)
        completions.append(future)

# Retrieve the results from the futures
results = [future.result() for future in completions]
print(results)