fix(tpm_rpm_routing_v2.py): fix tpm rpm routing

Krrish Dholakia 2024-04-18 20:01:07 -07:00
parent ee622e248d
commit 72691e05f4
4 changed files with 456 additions and 14 deletions


@@ -89,6 +89,13 @@ class InMemoryCache(BaseCache):
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# read the current value (default 0) and add the increment
init_value = self.get_cache(key=key) or 0
value = init_value + value
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
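For orientation, a minimal usage sketch of the new in-memory counter (not part of the diff); the key name is made up for illustration:
# Illustrative use of the new InMemoryCache.increment_cache (hypothetical key)
from litellm.caching import InMemoryCache
cache = InMemoryCache()
cache.increment_cache(key="1234:tpm:20-15", value=50)         # returns 50
print(cache.increment_cache(key="1234:tpm:20-15", value=25))  # 75; plain read-modify-write, not atomic across processes
print(cache.get_cache(key="1234:tpm:20-15"))                  # 75
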
@@ -198,6 +205,38 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
)
def increment_cache(self, key, value: int, **kwargs) -> int:
_redis_client = self.redis_client
start_time = time.time()
try:
result = _redis_client.incr(name=key, amount=value)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
)
)
return result
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
)
)
verbose_logger.error(
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
raise e
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
start_time = time.time()
try:
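The Redis path relies on INCRBY being atomic on the server, which is what lets multiple router instances bump the same minute counter safely. A rough standalone sketch of that behaviour with redis-py (not part of the diff; the local connection details are assumptions):
# Sketch of the atomic counter semantics used above (assumes a local Redis on the default port)
import redis
r = redis.Redis(host="localhost", port=6379)
r.delete("1234:tpm:20-15")
print(r.incr(name="1234:tpm:20-15", amount=50))  # 50
print(r.incr(name="1234:tpm:20-15", amount=25))  # 75; the server applies each increment atomically
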
@@ -1093,6 +1132,30 @@ class DualCache(BaseCache):
except Exception as e:
print_verbose(e)
def increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
result = self.redis_cache.increment_cache(key, value, **kwargs)
return result
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
raise e
def get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first
try:
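A minimal sketch (not part of the diff) of how the new DualCache.increment_cache behaves with only the default in-memory layer; the key name is illustrative, and local_only simply skips the Redis write when a RedisCache is attached:
# Illustrative DualCache usage (hypothetical key, no Redis attached)
from litellm.caching import DualCache
dual_cache = DualCache()  # falls back to the in-memory cache only
assert dual_cache.increment_cache(key="1234:tpm:20-15", value=50) == 50
assert dual_cache.increment_cache(key="1234:tpm:20-15", value=25, local_only=True) == 75
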


@@ -143,26 +143,18 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_key = f"{model_group}:tpm:{current_minute}"
rpm_key = f"{model_group}:rpm:{current_minute}"
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"{id}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
## TPM
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
## RPM
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
self.router_cache.increment_cache(key=tpm_key, value=total_tokens)
### TESTING ###
if self.test_flag:
self.logged_success += 1
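Net effect of this hunk: token usage is now tracked as flat integer counters keyed by deployment id and the current UTC minute, instead of per-model-group dictionaries. A sketch of the resulting layout (not part of the diff; key values inferred from the change, not an exact excerpt):
# Sketch of the new per-deployment counter layout (minute value is hypothetical)
from litellm.caching import DualCache
router_cache = DualCache()
deployment_id = "1234"
current_minute = "20-15"  # the handler derives this from get_utc_datetime().strftime("%H-%M")
tpm_key = f"{deployment_id}:tpm:{current_minute}"
router_cache.increment_cache(key=tpm_key, value=50)  # tokens used by this deployment in this minute
print(router_cache.get_cache(key=tpm_key))  # 50
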


@@ -0,0 +1,387 @@
#### What this tests ####
# This tests the router's ability to pick the deployment with the lowest tpm using 'usage-based-routing-v2'
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm import Router
import litellm
from litellm.router_strategy.lowest_tpm_rpm_v2 import (
LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler,
)
from litellm.caching import DualCache
### UNIT TESTS FOR TPM/RPM ROUTING ###
def test_tpm_rpm_updated():
test_cache = DualCache()
model_list = []
lowest_tpm_logger = LowestTPMLoggingHandler(
router_cache=test_cache, model_list=model_list
)
model_group = "gpt-3.5-turbo"
deployment_id = "1234"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": deployment_id},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
end_time = time.time()
lowest_tpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
current_minute = datetime.now().strftime("%H-%M")
tpm_count_api_key = f"{model_group}:tpm:{current_minute}"
rpm_count_api_key = f"{model_group}:rpm:{current_minute}"
assert (
response_obj["usage"]["total_tokens"]
== test_cache.get_cache(key=tpm_count_api_key)[deployment_id]
)
assert 1 == test_cache.get_cache(key=rpm_count_api_key)[deployment_id]
# test_tpm_rpm_updated()
def test_get_available_deployments():
test_cache = DualCache()
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "azure/chatgpt-v-2"},
"model_info": {"id": "1234"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "azure/chatgpt-v-2"},
"model_info": {"id": "5678"},
},
]
lowest_tpm_logger = LowestTPMLoggingHandler(
router_cache=test_cache, model_list=model_list
)
model_group = "gpt-3.5-turbo"
## DEPLOYMENT 1 ##
deployment_id = "1234"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": deployment_id},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
end_time = time.time()
lowest_tpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## DEPLOYMENT 2 ##
deployment_id = "5678"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": deployment_id},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 20}}
end_time = time.time()
lowest_tpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## CHECK WHAT'S SELECTED ##
print(
lowest_tpm_logger.get_available_deployments(
model_group=model_group,
healthy_deployments=model_list,
input=["Hello world"],
)
)
assert (
lowest_tpm_logger.get_available_deployments(
model_group=model_group,
healthy_deployments=model_list,
input=["Hello world"],
)["model_info"]["id"]
== "5678"
)
# test_get_available_deployments()
def test_router_get_available_deployments():
"""
Test whether the router's 'get_available_deployments' returns the least busy deployment
"""
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing",
set_verbose=False,
num_retries=3,
) # type: ignore
print(f"router id's: {router.get_model_ids()}")
## DEPLOYMENT 1 ##
deployment_id = 1
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "azure-model",
},
"model_info": {"id": 1},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
end_time = time.time()
router.lowesttpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## DEPLOYMENT 2 ##
deployment_id = 2
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "azure-model",
},
"model_info": {"id": 2},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 20}}
end_time = time.time()
router.lowesttpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## CHECK WHAT'S SELECTED ##
# print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
assert (
router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2"
)
# test_get_available_deployments()
# test_router_get_available_deployments()
def test_router_skip_rate_limited_deployments():
"""
Test whether the router's 'get_available_deployments' raises a 'No Models Available' error when the message would push usage past the deployment's max tpm
"""
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"tpm": 1440,
},
"model_info": {"id": 1},
},
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing",
set_verbose=False,
num_retries=3,
) # type: ignore
## DEPLOYMENT 1 ##
deployment_id = 1
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "azure-model",
},
"model_info": {"id": deployment_id},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 1439}}
end_time = time.time()
router.lowesttpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## CHECK WHAT'S SELECTED ##
# print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
try:
router.get_available_deployment(
model="azure-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
pytest.fail(f"Should have raised No Models Available error")
except Exception as e:
print(f"An exception occurred! {str(e)}")
def test_single_deployment_tpm_zero():
import litellm
import os
from datetime import datetime
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
"tpm": 0,
},
}
]
router = litellm.Router(
model_list=model_list,
routing_strategy="usage-based-routing",
cache_responses=True,
)
model = "gpt-3.5-turbo"
messages = [{"content": "Hello, how are you?", "role": "user"}]
try:
router.get_available_deployment(
model=model,
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
pytest.fail(f"Should have raised No Models Available error")
except Exception as e:
print(f"it worked - {str(e)}! \n{traceback.format_exc()}")
@pytest.mark.asyncio
async def test_router_completion_streaming():
messages = [
{"role": "user", "content": "Hello, can you generate a 500 words poem?"}
]
model = "azure-model"
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing",
set_verbose=False,
) # type: ignore
### Make 3 calls, test if 3rd call goes to lowest tpm deployment
## CALL 1+2
tasks = []
response = None
final_response = None
for _ in range(2):
tasks.append(router.acompletion(model=model, messages=messages))
response = await asyncio.gather(*tasks)
if response is not None:
## CALL 3
await asyncio.sleep(1) # let the token update happen
current_minute = datetime.now().strftime("%H-%M")
picked_deployment = router.lowesttpm_logger.get_available_deployments(
model_group=model,
healthy_deployments=router.healthy_deployments,
messages=messages,
)
final_response = await router.acompletion(model=model, messages=messages)
print(f"min deployment id: {picked_deployment}")
tpm_key = f"{model}:tpm:{current_minute}"
rpm_key = f"{model}:rpm:{current_minute}"
tpm_dict = router.cache.get_cache(key=tpm_key)
print(f"tpm_dict: {tpm_dict}")
rpm_dict = router.cache.get_cache(key=rpm_key)
print(f"rpm_dict: {rpm_dict}")
print(f"model id: {final_response._hidden_params['model_id']}")
assert (
final_response._hidden_params["model_id"]
== picked_deployment["model_info"]["id"]
)
# asyncio.run(test_router_completion_streaming())