Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
fix(tpm_rpm_routing_v2.py): fix tpm rpm routing
This commit is contained in:
parent ee622e248d
commit 72691e05f4

4 changed files with 456 additions and 14 deletions

litellm/caching.py
@@ -89,6 +89,13 @@ class InMemoryCache(BaseCache):
             return_val.append(val)
         return return_val
 
+    def increment_cache(self, key, value: int, **kwargs) -> int:
+        # get the value
+        init_value = self.get_cache(key=key) or 0
+        value = init_value + value
+        self.set_cache(key, value, **kwargs)
+        return value
+
     async def async_get_cache(self, key, **kwargs):
         return self.get_cache(key=key, **kwargs)
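
The in-memory variant is a plain read-modify-write over get_cache/set_cache. A minimal usage sketch, not part of this commit, assuming InMemoryCache is importable from litellm.caching (as DualCache is in the new test file below) and needs no constructor arguments:

    # Illustration only: repeated increments accumulate under the same key.
    from litellm.caching import InMemoryCache  # assumed import path

    cache = InMemoryCache()
    cache.increment_cache(key="1234:tpm:10-15", value=50)
    total = cache.increment_cache(key="1234:tpm:10-15", value=20)
    print(total)  # 70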
@@ -198,6 +205,38 @@ class RedisCache(BaseCache):
                 f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
             )
 
+    def increment_cache(self, key, value: int, **kwargs) -> int:
+        _redis_client = self.redis_client
+        start_time = time.time()
+        try:
+            result = _redis_client.incr(name=key, amount=value)
+            ## LOGGING ##
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.service_success_hook(
+                    service=ServiceTypes.REDIS,
+                    duration=_duration,
+                )
+            )
+            return result
+        except Exception as e:
+            ## LOGGING ##
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_failure_hook(
+                    service=ServiceTypes.REDIS, duration=_duration, error=e
+                )
+            )
+            verbose_logger.error(
+                "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
+                str(e),
+                value,
+            )
+            traceback.print_exc()
+            raise e
+
     async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
         start_time = time.time()
         try:
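
Unlike the in-memory path, the Redis path delegates the arithmetic to the server: redis-py's incr issues an atomic INCRBY, so concurrent workers sharing one Redis instance do not lose updates. A rough stand-alone sketch of that underlying call, assuming a local Redis on the default port:

    # Illustration only: the same increment expressed directly against redis-py.
    import redis

    r = redis.Redis(host="localhost", port=6379)
    r.incr(name="1234:tpm:10-15", amount=50)  # atomic INCRBY on the server
    print(r.get("1234:tpm:10-15"))  # b'50'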
@@ -1093,6 +1132,30 @@ class DualCache(BaseCache):
         except Exception as e:
             print_verbose(e)
 
+    def increment_cache(
+        self, key, value: int, local_only: bool = False, **kwargs
+    ) -> int:
+        """
+        Key - the key in cache
+
+        Value - int - the value you want to increment by
+
+        Returns - int - the incremented value
+        """
+        try:
+            result: int = value
+            if self.in_memory_cache is not None:
+                result = self.in_memory_cache.increment_cache(key, value, **kwargs)
+
+            if self.redis_cache is not None and local_only == False:
+                result = self.redis_cache.increment_cache(key, value, **kwargs)
+
+            return result
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
+            traceback.print_exc()
+            raise e
+
     def get_cache(self, key, local_only: bool = False, **kwargs):
         # Try to fetch from in-memory cache first
         try:
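
DualCache.increment_cache fans the same call out to both layers and returns the last result: the Redis value when Redis is configured and local_only is False, otherwise the in-memory value. A sketch of how a caller might use it, grounded in the test file below where DualCache() is constructed with no arguments:

    # Illustration only: with no Redis configured this exercises the in-memory layer.
    from litellm.caching import DualCache

    dual_cache = DualCache()
    dual_cache.increment_cache(key="1234:tpm:10-15", value=50)
    new_total = dual_cache.increment_cache(key="1234:tpm:10-15", value=20)
    print(new_total)  # 70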

litellm/router_strategy/lowest_tpm_rpm_v2.py
@@ -143,26 +143,18 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 # Setup values
                 # ------------
                 dt = get_utc_datetime()
-                current_minute = dt.strftime("%H-%M")
-                tpm_key = f"{model_group}:tpm:{current_minute}"
-                rpm_key = f"{model_group}:rpm:{current_minute}"
+                current_minute = dt.strftime(
+                    "%H-%M"
+                )  # use the same timezone regardless of system clock
+
+                tpm_key = f"{id}:tpm:{current_minute}"
                 # ------------
                 # Update usage
                 # ------------
+                # update cache
 
                 ## TPM
-                request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
-                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
-
-                self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
-
-                ## RPM
-                request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
-                request_count_dict[id] = request_count_dict.get(id, 0) + 1
-
-                self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+                self.router_cache.increment_cache(key=tpm_key, value=total_tokens)
 
                 ### TESTING ###
                 if self.test_flag:
                     self.logged_success += 1
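
The behavioral change in this hunk: the success handler previously read a per-minute dict keyed by model group and rewrote the whole dict on every request; it now bumps a per-deployment counter key directly through increment_cache, and the minute component comes from get_utc_datetime so every router instance agrees on the window regardless of the local clock. A sketch of the resulting key shape, with datetime.now(timezone.utc) standing in for get_utc_datetime and an illustrative deployment id:

    # Illustration only: how the new per-deployment tpm key is built.
    from datetime import datetime, timezone

    deployment_id = "1234"  # the handler appears to take this from the deployment's model_info["id"]
    current_minute = datetime.now(timezone.utc).strftime("%H-%M")
    tpm_key = f"{deployment_id}:tpm:{current_minute}"  # e.g. "1234:tpm:10-15"
    # router_cache.increment_cache(key=tpm_key, value=total_tokens)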

litellm/tests/test_tpm_rpm_routing_v2.py (new file, 387 lines)
@@ -0,0 +1,387 @@
#### What this tests ####
# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2'

import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
from litellm import Router
import litellm
from litellm.router_strategy.lowest_tpm_rpm_v2 import (
    LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler,
)
from litellm.caching import DualCache

### UNIT TESTS FOR TPM/RPM ROUTING ###

def test_tpm_rpm_updated():
    test_cache = DualCache()
    model_list = []
    lowest_tpm_logger = LowestTPMLoggingHandler(
        router_cache=test_cache, model_list=model_list
    )
    model_group = "gpt-3.5-turbo"
    deployment_id = "1234"
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "gpt-3.5-turbo",
                "deployment": "azure/chatgpt-v-2",
            },
            "model_info": {"id": deployment_id},
        }
    }
    start_time = time.time()
    response_obj = {"usage": {"total_tokens": 50}}
    end_time = time.time()
    lowest_tpm_logger.log_success_event(
        response_obj=response_obj,
        kwargs=kwargs,
        start_time=start_time,
        end_time=end_time,
    )
    current_minute = datetime.now().strftime("%H-%M")
    tpm_count_api_key = f"{model_group}:tpm:{current_minute}"
    rpm_count_api_key = f"{model_group}:rpm:{current_minute}"
    assert (
        response_obj["usage"]["total_tokens"]
        == test_cache.get_cache(key=tpm_count_api_key)[deployment_id]
    )
    assert 1 == test_cache.get_cache(key=rpm_count_api_key)[deployment_id]


# test_tpm_rpm_updated()

def test_get_available_deployments():
    test_cache = DualCache()
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "azure/chatgpt-v-2"},
            "model_info": {"id": "1234"},
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "azure/chatgpt-v-2"},
            "model_info": {"id": "5678"},
        },
    ]
    lowest_tpm_logger = LowestTPMLoggingHandler(
        router_cache=test_cache, model_list=model_list
    )
    model_group = "gpt-3.5-turbo"
    ## DEPLOYMENT 1 ##
    deployment_id = "1234"
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "gpt-3.5-turbo",
                "deployment": "azure/chatgpt-v-2",
            },
            "model_info": {"id": deployment_id},
        }
    }
    start_time = time.time()
    response_obj = {"usage": {"total_tokens": 50}}
    end_time = time.time()
    lowest_tpm_logger.log_success_event(
        response_obj=response_obj,
        kwargs=kwargs,
        start_time=start_time,
        end_time=end_time,
    )
    ## DEPLOYMENT 2 ##
    deployment_id = "5678"
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "gpt-3.5-turbo",
                "deployment": "azure/chatgpt-v-2",
            },
            "model_info": {"id": deployment_id},
        }
    }
    start_time = time.time()
    response_obj = {"usage": {"total_tokens": 20}}
    end_time = time.time()
    lowest_tpm_logger.log_success_event(
        response_obj=response_obj,
        kwargs=kwargs,
        start_time=start_time,
        end_time=end_time,
    )

    ## CHECK WHAT'S SELECTED ##
    print(
        lowest_tpm_logger.get_available_deployments(
            model_group=model_group,
            healthy_deployments=model_list,
            input=["Hello world"],
        )
    )
    assert (
        lowest_tpm_logger.get_available_deployments(
            model_group=model_group,
            healthy_deployments=model_list,
            input=["Hello world"],
        )["model_info"]["id"]
        == "5678"
    )


# test_get_available_deployments()

def test_router_get_available_deployments():
    """
    Test if routers 'get_available_deployments' returns the least busy deployment
    """
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "rpm": 1440,
            },
            "model_info": {"id": 1},
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
                "rpm": 6,
            },
            "model_info": {"id": 2},
        },
    ]
    router = Router(
        model_list=model_list,
        routing_strategy="usage-based-routing",
        set_verbose=False,
        num_retries=3,
    )  # type: ignore

    print(f"router id's: {router.get_model_ids()}")
    ## DEPLOYMENT 1 ##
    deployment_id = 1
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "azure-model",
            },
            "model_info": {"id": 1},
        }
    }
    start_time = time.time()
    response_obj = {"usage": {"total_tokens": 50}}
    end_time = time.time()
    router.lowesttpm_logger.log_success_event(
        response_obj=response_obj,
        kwargs=kwargs,
        start_time=start_time,
        end_time=end_time,
    )
    ## DEPLOYMENT 2 ##
    deployment_id = 2
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "azure-model",
            },
            "model_info": {"id": 2},
        }
    }
    start_time = time.time()
    response_obj = {"usage": {"total_tokens": 20}}
    end_time = time.time()
    router.lowesttpm_logger.log_success_event(
        response_obj=response_obj,
        kwargs=kwargs,
        start_time=start_time,
        end_time=end_time,
    )

    ## CHECK WHAT'S SELECTED ##
    # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
    assert (
        router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2"
    )


# test_get_available_deployments()
# test_router_get_available_deployments()

def test_router_skip_rate_limited_deployments():
    """
    Test if routers 'get_available_deployments' raises No Models Available error if max tpm would be reached by message
    """
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "tpm": 1440,
            },
            "model_info": {"id": 1},
        },
    ]
    router = Router(
        model_list=model_list,
        routing_strategy="usage-based-routing",
        set_verbose=False,
        num_retries=3,
    )  # type: ignore

    ## DEPLOYMENT 1 ##
    deployment_id = 1
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "azure-model",
            },
            "model_info": {"id": deployment_id},
        }
    }
    start_time = time.time()
    response_obj = {"usage": {"total_tokens": 1439}}
    end_time = time.time()
    router.lowesttpm_logger.log_success_event(
        response_obj=response_obj,
        kwargs=kwargs,
        start_time=start_time,
        end_time=end_time,
    )

    ## CHECK WHAT'S SELECTED ##
    # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
    try:
        router.get_available_deployment(
            model="azure-model",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )
        pytest.fail(f"Should have raised No Models Available error")
    except Exception as e:
        print(f"An exception occurred! {str(e)}")

def test_single_deployment_tpm_zero():
    import litellm
    import os
    from datetime import datetime

    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
                "tpm": 0,
            },
        }
    ]

    router = litellm.Router(
        model_list=model_list,
        routing_strategy="usage-based-routing",
        cache_responses=True,
    )

    model = "gpt-3.5-turbo"
    messages = [{"content": "Hello, how are you?", "role": "user"}]
    try:
        router.get_available_deployment(
            model=model,
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )
        pytest.fail(f"Should have raised No Models Available error")
    except Exception as e:
        print(f"it worked - {str(e)}! \n{traceback.format_exc()}")

@pytest.mark.asyncio
async def test_router_completion_streaming():
    messages = [
        {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
    ]
    model = "azure-model"
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "rpm": 1440,
            },
            "model_info": {"id": 1},
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
                "rpm": 6,
            },
            "model_info": {"id": 2},
        },
    ]
    router = Router(
        model_list=model_list,
        routing_strategy="usage-based-routing",
        set_verbose=False,
    )  # type: ignore

    ### Make 3 calls, test if 3rd call goes to lowest tpm deployment

    ## CALL 1+2
    tasks = []
    response = None
    final_response = None
    for _ in range(2):
        tasks.append(router.acompletion(model=model, messages=messages))
    response = await asyncio.gather(*tasks)

    if response is not None:
        ## CALL 3
        await asyncio.sleep(1)  # let the token update happen
        current_minute = datetime.now().strftime("%H-%M")
        picked_deployment = router.lowesttpm_logger.get_available_deployments(
            model_group=model,
            healthy_deployments=router.healthy_deployments,
            messages=messages,
        )
        final_response = await router.acompletion(model=model, messages=messages)
        print(f"min deployment id: {picked_deployment}")
        tpm_key = f"{model}:tpm:{current_minute}"
        rpm_key = f"{model}:rpm:{current_minute}"

        tpm_dict = router.cache.get_cache(key=tpm_key)
        print(f"tpm_dict: {tpm_dict}")
        rpm_dict = router.cache.get_cache(key=rpm_key)
        print(f"rpm_dict: {rpm_dict}")
        print(f"model id: {final_response._hidden_params['model_id']}")
        assert (
            final_response._hidden_params["model_id"]
            == picked_deployment["model_info"]["id"]
        )


# asyncio.run(test_router_completion_streaming())
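
To run just these tests locally, one option is to invoke pytest against the new file from the repository root; the keys referenced above (OPENAI_API_KEY, AZURE_FRANCE_API_KEY, AZURE_EUROPE_API_KEY) would need to be present in the .env that load_dotenv() picks up. A sketch, with the path assumed relative to the repo root:

    # Illustration only: programmatic pytest invocation for this one test module.
    import pytest

    pytest.main(["-x", "-s", "litellm/tests/test_tpm_rpm_routing_v2.py"])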