BerriAI/litellm (mirror of https://github.com/BerriAI/litellm.git)
commit 204218508d (parent 608a27d68b)
fix(router.py): completing redis support work for router

3 changed files with 45 additions and 34 deletions
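In outline, the commit replaces the Router's private in-module Cache with litellm.Cache, which is Redis-backed when credentials are supplied. A minimal usage sketch assembled from the updated test below (the REDIS_* environment variables are assumed to be set, and the rate limits are illustrative):

import os
from litellm import Router

# Single deployment entry, abbreviated from the test file below.
model_list = [{
    "model_name": "gpt-3.5-turbo",                     # alias callers use
    "litellm_params": {"model": "azure/chatgpt-v-2"},  # underlying deployment
    "rpm": 9000,
}]

# With Redis credentials the per-minute usage counters live in Redis and are
# shared across Router instances; without them the Router falls back to an
# in-memory cache.
router = Router(
    model_list=model_list,
    redis_host=os.getenv("REDIS_HOST"),
    redis_password=os.getenv("REDIS_PASSWORD"),
    redis_port=os.getenv("REDIS_PORT"),
)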
litellm/caching.py

@@ -158,7 +158,10 @@ class Cache():
            The cached result if it exists, otherwise None.
        """
        try: # never block execution
-            cache_key = self.get_cache_key(*args, **kwargs)
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            else:
+                cache_key = self.get_cache_key(*args, **kwargs)
            if cache_key is not None:
                cached_result = self.cache.get_cache(cache_key)
                if cached_result != None and 'stream' in kwargs and kwargs['stream'] == True:
@@ -182,7 +185,10 @@ class Cache():
            None
        """
        try:
-            cache_key = self.get_cache_key(*args, **kwargs)
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            else:
+                cache_key = self.get_cache_key(*args, **kwargs)
            # print("adding to cache", cache_key, result)
            # print(cache_key)
            if cache_key is not None:
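The change above lets callers supply an explicit cache_key instead of deriving one from the completion() arguments, which is what the Router's new increment() helper (in router.py below) relies on. A minimal sketch, assuming litellm's Cache defaults to the in-memory backend when constructed without credentials:

import litellm

cache = litellm.Cache()  # in-memory by default; Redis if credentials are passed

# Store a value under an explicit key, bypassing key derivation
# from completion()/embedding() arguments.
cache.add_cache(result=42, cache_key="gpt-3.5-turbo:rpm:14-05")

# Read it back the same way; per the docstring above, returns None
# when the key is absent.
value = cache.get_cache(cache_key="gpt-3.5-turbo:rpm:14-05")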
litellm/router.py

@@ -2,25 +2,6 @@ from typing import Union, List, Dict, Optional
 from datetime import datetime
 import litellm
-
-class Cache:
-    """
-    Underlying dictionary for Router. This can either be a dictionary or a Redis Cache (if credentials are set).
-    """
-    def __init__(self, cache_config: dict) -> None:
-        self.cache_config = cache_config
-        if cache_config["type"] == "redis":
-            pass
-        elif cache_config["type"] == "local":
-            self.usage_dict: Dict = {}
-    def get(self, key: str):
-        return self.usage_dict.get(key, 0)
-
-    def increment(self, key: str, increment_value: int, expiry: int):
-        if self.cache_config["type"] == "redis":
-            pass
-        elif self.cache_config["type"] == "local":
-            self.usage_dict[key] = self.usage_dict.get(key, 0) + increment_value

 class Router:
     """
     Example usage:
@@ -52,12 +33,12 @@ class Router:
                'port': redis_port,
                'password': redis_password
            }
-        else:
+        else: # use an in-memory cache
            cache_config = {
                "type": "local"
            }
-        self.cache = Cache(cache_config)
-        litellm.cache = litellm.Cache(**cache_config)
+        self.cache = litellm.Cache(cache_config) # use Redis for tracking load balancing
+        litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
        litellm.success_callback = [self.deployment_callback]

    def completion(self,
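For reference, the two cache_config shapes the Router constructor builds in the code around this hunk. The "type": "redis" and "host" entries are implied by the surrounding lines but not shown in the hunk itself, and the values here are illustrative:

# When redis_host/redis_port/redis_password are passed to Router:
cache_config = {
    "type": "redis",          # implied, not visible in the hunk
    "host": "my-redis-host",  # redis_host (illustrative)
    "port": 6379,             # redis_port
    "password": "secret",     # redis_password
}

# Otherwise, the in-memory fallback:
cache_config = {"type": "local"}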
@@ -182,10 +163,24 @@ class Router:
        # ------------
        # Return usage
        # ------------
-        tpm = self.cache.get(tpm_key)
-        rpm = self.cache.get(rpm_key)
+        tpm = self.cache.get_cache(tpm_key)
+        rpm = self.cache.get_cache(rpm_key)
+
+        if tpm is None:
+            tpm = 0
+        if rpm is None:
+            rpm = 0
+
        return int(tpm), int(rpm)

+    def increment(self, key: str, increment_value: int):
+        # get value
+        cached_value = self.cache.get_cache(key)
+        # update value
+        cached_value = cached_value + increment_value
+        # save updated value
+        self.cache.add_cache(result=cached_value, cache_key=key)
+
    def _set_deployment_usage(
        self,
        model_name: str,
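One caveat with the new increment() above: get_cache returns None for a missing key (see the caching.py docstring earlier), so the first bump of a fresh per-minute key would add an int to None. A hedged variant that guards the first write, not part of the commit:

def increment(self, key: str, increment_value: int):
    # Default to 0 when the per-minute key has not been written yet.
    cached_value = self.cache.get_cache(key) or 0
    # Update and persist the counter under the same explicit key.
    self.cache.add_cache(result=cached_value + increment_value, cache_key=key)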
@@ -195,12 +190,11 @@ class Router:
        # Setup values
        # ------------
        current_minute = datetime.now().strftime("%H-%M")
-        ttl = 120 # 2 minutes
        tpm_key = f'{model_name}:tpm:{current_minute}'
        rpm_key = f'{model_name}:rpm:{current_minute}'

        # ------------
        # Update usage
        # ------------
-        self.cache.increment(tpm_key, total_tokens, expiry=ttl)
-        self.cache.increment(rpm_key, 1, expiry=ttl)
+        self.cache.increment(tpm_key, total_tokens)
+        self.cache.increment(rpm_key, 1)
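Note that this hunk drops the ttl/expiry handling, so the per-minute counter keys are no longer expired after two minutes. For comparison, the removed behavior corresponds to this pattern when talking to Redis directly (an illustrative redis-py sketch, not the commit's code):

import redis  # assumes the redis-py client is installed

r = redis.Redis(host="localhost", port=6379)

def increment_with_ttl(key: str, amount: int, ttl: int = 120) -> None:
    # Bump the per-minute counter and (re)set a 2-minute expiry in one
    # round trip, so stale minute buckets clean themselves up.
    pipe = r.pipeline()
    pipe.incrby(key, amount)
    pipe.expire(key, ttl)
    pipe.execute()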
litellm/tests/test_router.py

@@ -8,11 +8,12 @@ sys.path.insert(
    0, os.path.abspath("../..")
) # Adds the parent directory to the system path
 from litellm import Router
+from concurrent.futures import ThreadPoolExecutor
 from dotenv import load_dotenv

 load_dotenv()

-model_list = [{
+model_list = [{ # list of model deployments
    "model_name": "gpt-3.5-turbo", # openai model name
    "litellm_params": { # params for litellm completion/embedding call
        "model": "azure/chatgpt-v-2",
@@ -42,9 +43,19 @@ model_list = [{
    "rpm": 9000
}]

-router = Router(model_list=model_list)
+router = Router(model_list=model_list, redis_host=os.getenv("REDIS_HOST"), redis_password=os.getenv("REDIS_PASSWORD"), redis_port=os.getenv("REDIS_PORT"))

-# openai.ChatCompletion.create replacement
-response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
+completions = []
+with ThreadPoolExecutor(max_workers=100) as executor:
+    kwargs = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hey, how's it going?"}]
+    }
+    for _ in range(20):
+        future = executor.submit(router.completion, **kwargs)
+        completions.append(future)

-print(response)
+# Retrieve the results from the futures
+results = [future.result() for future in completions]
+
+print(results)
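The counter keys this concurrent test exercises follow the per-minute scheme from _set_deployment_usage above; for example:

from datetime import datetime

model_name = "gpt-3.5-turbo"
current_minute = datetime.now().strftime("%H-%M")
tpm_key = f"{model_name}:tpm:{current_minute}"  # tokens used this minute
rpm_key = f"{model_name}:rpm:{current_minute}"  # requests made this minute
print(tpm_key, rpm_key)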