mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix(router.py): completing redis support work for router
This commit is contained in:
parent 608a27d68b
commit 204218508d
3 changed files with 45 additions and 34 deletions
litellm/caching.py

@@ -158,6 +158,9 @@ class Cache():
             The cached result if it exists, otherwise None.
         """
         try: # never block execution
-            cache_key = self.get_cache_key(*args, **kwargs)
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            else:
+                cache_key = self.get_cache_key(*args, **kwargs)
             if cache_key is not None:
                 cached_result = self.cache.get_cache(cache_key)
@@ -182,6 +185,9 @@ class Cache():
             None
         """
         try:
-            cache_key = self.get_cache_key(*args, **kwargs)
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            else:
+                cache_key = self.get_cache_key(*args, **kwargs)
             # print("adding to cache", cache_key, result)
             # print(cache_key)
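With these two hunks, `Cache.get_cache` and `Cache.add_cache` accept an explicit `cache_key` kwarg instead of always deriving the key from the call arguments via `get_cache_key`. A minimal sketch of the resulting usage (the no-argument `litellm.Cache()` constructor defaulting to an in-memory cache is an assumption here, not shown in the diff):

```python
import litellm

# Sketch: pin the key explicitly rather than letting get_cache_key()
# derive it from the completion args (assumes the local-cache default).
cache = litellm.Cache()

cache.add_cache(result=42, cache_key="gpt-3.5-turbo:tpm:14-05")
print(cache.get_cache(cache_key="gpt-3.5-turbo:tpm:14-05"))  # 42
```

This is what lets the Router below store arbitrary counters under its own tpm/rpm keys.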
litellm/router.py

@@ -2,25 +2,6 @@ from typing import Union, List, Dict, Optional
 from datetime import datetime
 import litellm
-
-class Cache:
-    """
-    Underlying dictionary for Router. This can either be a dictionary or a Redis Cache (if credentials are set).
-    """
-    def __init__(self, cache_config: dict) -> None:
-        self.cache_config = cache_config
-        if cache_config["type"] == "redis":
-            pass
-        elif cache_config["type"] == "local":
-            self.usage_dict: Dict = {}
-    def get(self, key: str):
-        return self.usage_dict.get(key, 0)
-
-    def increment(self, key: str, increment_value: int, expiry: int):
-        if self.cache_config["type"] == "redis":
-            pass
-        elif self.cache_config["type"] == "local":
-            self.usage_dict[key] = self.usage_dict.get(key, 0) + increment_value
-
 class Router:
     """
     Example usage:
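The Router's hand-rolled Cache wrapper is deleted outright; note its Redis branches were still `pass`, so only the local dict ever worked. Everything the Router needs from a cache is now the `get_cache`/`add_cache` pair on `litellm.Cache`. A dict-backed stand-in for that interface (a sketch for intuition, not litellm's implementation) is just:

```python
from typing import Any, Dict

class DictCache:
    """Minimal stand-in for the cache interface the Router relies on."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    def get_cache(self, cache_key: str) -> Any:
        # None on a miss, matching litellm.Cache.get_cache
        return self._store.get(cache_key)

    def add_cache(self, result: Any, cache_key: str) -> None:
        self._store[cache_key] = result
```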
@@ -52,12 +33,12 @@ class Router:
                 'port': redis_port,
                 'password': redis_password
             }
-        else:
+        else: # use an in-memory cache
             cache_config = {
                 "type": "local"
             }
-        self.cache = Cache(cache_config)
-        litellm.cache = litellm.Cache(**cache_config)
+        self.cache = litellm.Cache(cache_config) # use Redis for tracking load balancing
+        litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
         litellm.success_callback = [self.deployment_callback]

     def completion(self,
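One `cache_config` now backs two cache instances: `self.cache` tracks per-deployment usage for load balancing, while the global `litellm.cache` caches completion responses, and both point at the same Redis when credentials are given. Reconstructed from the surrounding context (the `"host"` key name is inferred from the visible `'port'`/`'password'` keys), the two config shapes are:

```python
import os

# With redis credentials passed to Router(...); the "host" key name is
# inferred, the rest follows the dict literal in the hunk above.
cache_config = {
    "type": "redis",
    "host": os.getenv("REDIS_HOST"),
    "port": os.getenv("REDIS_PORT"),
    "password": os.getenv("REDIS_PASSWORD"),
}

# Without credentials: the in-memory fallback
cache_config = {"type": "local"}
```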
@@ -182,10 +163,24 @@ class Router:
         # ------------
         # Return usage
         # ------------
-        tpm = self.cache.get(tpm_key)
-        rpm = self.cache.get(rpm_key)
+        tpm = self.cache.get_cache(tpm_key)
+        rpm = self.cache.get_cache(rpm_key)
+
+        if tpm is None:
+            tpm = 0
+        if rpm is None:
+            rpm = 0
+
         return int(tpm), int(rpm)

+    def increment(self, key: str, increment_value: int):
+        # get value
+        cached_value = self.cache.get_cache(key)
+        # update value
+        cached_value = cached_value + increment_value
+        # save updated value
+        self.cache.add_cache(result=cached_value, cache_key=key)
+
     def _set_deployment_usage(
         self,
         model_name: str,
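The usage reads switch from the deleted wrapper's `get` (which defaulted to 0) to `get_cache`, so the `None` checks are needed to keep a cache miss reading as zero. The new `Router.increment` is a plain read-modify-write: fetch, add, write back. That is simple but not atomic, so two concurrent requests can read the same value and lose an update, and it will raise on a key that has never been written, since `get_cache` returns `None`. For comparison only, a Redis-native counter via redis-py would be atomic server-side (a sketch, not what this commit does):

```python
import redis

r = redis.Redis(host="localhost", port=6379)

def increment(key: str, increment_value: int, ttl: int = 120) -> None:
    # INCRBY creates the key at 0 if missing and adds atomically;
    # EXPIRE bounds how long a per-minute bucket lives.
    r.incrby(key, increment_value)
    r.expire(key, ttl)
```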
@@ -195,12 +190,11 @@ class Router:
         # Setup values
         # ------------
         current_minute = datetime.now().strftime("%H-%M")
-        ttl = 120 # 2 minutes
         tpm_key = f'{model_name}:tpm:{current_minute}'
         rpm_key = f'{model_name}:rpm:{current_minute}'

         # ------------
         # Update usage
         # ------------
-        self.cache.increment(tpm_key, total_tokens, expiry=ttl)
-        self.cache.increment(rpm_key, 1, expiry=ttl)
+        self.cache.increment(tpm_key, total_tokens)
+        self.cache.increment(rpm_key, 1)
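The `ttl` and `expiry=` arguments disappear because the increment path now goes through `get_cache`/`add_cache`, which take no expiry here. That still works for load balancing because the keys are scoped to the current minute, so each minute reads and writes a fresh bucket:

```python
from datetime import datetime

model_name = "gpt-3.5-turbo"
current_minute = datetime.now().strftime("%H-%M")  # e.g. "14-05"
tpm_key = f"{model_name}:tpm:{current_minute}"     # tokens this minute
rpm_key = f"{model_name}:rpm:{current_minute}"     # requests this minute
print(tpm_key, rpm_key)
```

The trade-off is that stale minute buckets are never evicted from Redis, which the dropped `expiry` used to handle.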
litellm/tests/test_router.py

@@ -8,11 +8,12 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from litellm import Router
+from concurrent.futures import ThreadPoolExecutor
 from dotenv import load_dotenv

 load_dotenv()

-model_list = [{
+model_list = [{ # list of model deployments
     "model_name": "gpt-3.5-turbo", # openai model name
     "litellm_params": { # params for litellm completion/embedding call
         "model": "azure/chatgpt-v-2",
@@ -42,9 +43,19 @@ model_list = [{
     "rpm": 9000
 }]

-router = Router(model_list=model_list)
+router = Router(model_list=model_list, redis_host=os.getenv("REDIS_HOST"), redis_password=os.getenv("REDIS_PASSWORD"), redis_port=os.getenv("REDIS_PORT"))

 # openai.ChatCompletion.create replacement
-response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
+completions = []
+with ThreadPoolExecutor(max_workers=100) as executor:
+    kwargs = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hey, how's it going?"}]
+    }
+    for _ in range(20):
+        future = executor.submit(router.completion, **kwargs)
+        completions.append(future)

-print(response)
+# Retrieve the results from the futures
+results = [future.result() for future in completions]
+
+print(results)
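The test goes from one completion call to 20 concurrent calls submitted through a `ThreadPoolExecutor`, so the Redis-backed tpm/rpm counters are exercised from many threads at once. Running it needs the three Redis variables in the environment; a placeholder setup (values are illustrative, and in the test they come from the `.env` file loaded by `load_dotenv()`):

```python
import os

# Illustrative placeholders; real values belong in .env
os.environ["REDIS_HOST"] = "localhost"
os.environ["REDIS_PORT"] = "6379"
os.environ["REDIS_PASSWORD"] = ""
```

Note that `os.getenv("REDIS_PORT")` returns a string, so the port reaches the Router as a string unless it coerces the value itself.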