fix(lowest_latency.py): add back tpm/rpm checks, configurable time window

Author: Krrish Dholakia
Date:   2024-01-10 20:52:01 +05:30
Parent: ec3e597b61
Commit: f288b12411
3 changed files with 312 additions and 51 deletions
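For context, a minimal usage sketch of the new knob (a hedged illustration, not part of the commit): it assumes the public `Router` constructor shown in the first hunk below; the model entry, API key handling, and TTL value are placeholders.

import os
from litellm import Router

# Hypothetical setup: latency-based routing with a custom time window.
# `routing_strategy_args` and its "ttl" key are what this commit introduces;
# the model list entry is illustrative and assumes OPENAI_API_KEY is set.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "model_info": {"id": "1234", "rpm": 10},  # optional per-deployment rpm/tpm limits
        },
    ],
    routing_strategy="latency-based-routing",
    routing_strategy_args={"ttl": 10},  # keep the latency/tpm/rpm map for 10 seconds
)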

litellm/router.py

@@ -105,6 +105,7 @@ class Router:
             "usage-based-routing",
             "latency-based-routing",
         ] = "simple-shuffle",
+        routing_strategy_args: dict = {},  # just for latency-based routing
     ) -> None:
         self.set_verbose = set_verbose
         self.deployment_names: List = (
@@ -217,7 +218,9 @@ class Router:
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
         elif routing_strategy == "latency-based-routing":
             self.lowestlatency_logger = LowestLatencyLoggingHandler(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache,
+                model_list=self.model_list,
+                routing_args=routing_strategy_args,
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowestlatency_logger)  # type: ignore
@@ -1427,9 +1430,8 @@ class Router:
                     http_client=httpx.AsyncClient(
                         transport=AsyncCustomHTTPTransport(),
                         limits=httpx.Limits(
-                            max_connections=1000,
-                            max_keepalive_connections=100
-                        )
+                            max_connections=1000, max_keepalive_connections=100
+                        ),
                     ),  # type: ignore
                 )
                 self.cache.set_cache(
@@ -1449,9 +1451,8 @@ class Router:
                     http_client=httpx.Client(
                         transport=CustomHTTPTransport(),
                         limits=httpx.Limits(
-                            max_connections=1000,
-                            max_keepalive_connections=100
-                        )
+                            max_connections=1000, max_keepalive_connections=100
+                        ),
                     ),  # type: ignore
                 )
                 self.cache.set_cache(
@@ -1471,10 +1472,9 @@ class Router:
                     max_retries=max_retries,
                     http_client=httpx.AsyncClient(
                         limits=httpx.Limits(
-                            max_connections=1000,
-                            max_keepalive_connections=100
+                            max_connections=1000, max_keepalive_connections=100
                         )
-                    )
+                    ),
                 )
                 self.cache.set_cache(
                     key=cache_key,
@@ -1492,10 +1492,9 @@ class Router:
                     max_retries=max_retries,
                     http_client=httpx.Client(
                         limits=httpx.Limits(
-                            max_connections=1000,
-                            max_keepalive_connections=100
+                            max_connections=1000, max_keepalive_connections=100
                         )
-                    )
+                    ),
                 )
                 self.cache.set_cache(
                     key=cache_key,

litellm/router_strategy/lowest_latency.py

@@ -1,25 +1,46 @@
 #### What this does ####
 # picks based on response time (for streaming, this is time to first token)
+from pydantic import BaseModel, Extra, Field, root_validator
 import dotenv, os, requests, random
-from typing import Optional
+from typing import Optional, Union, List, Dict
 from datetime import datetime, timedelta
 
 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
+from litellm import ModelResponse
+from litellm import token_counter
+
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60 * 60  # 1 hour
+
+
 class LowestLatencyLoggingHandler(CustomLogger):
     test_flag: bool = False
     logged_success: int = 0
     logged_failure: int = 0
-    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(
+        self, router_cache: DualCache, model_list: list, routing_args: dict = {}
+    ):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
@@ -37,25 +58,64 @@ class LowestLatencyLoggingHandler(CustomLogger):
             if model_group is None or id is None:
                 return
 
-            response_ms = end_time - start_time
-
             # ------------
             # Setup values
             # ------------
-            latency_key = f"{model_group}_latency_map"
+            """
+            {
+                {model_group}_map: {
+                    id: {
+                        "latency": [..]
+                        f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
+                    }
+                }
+            }
+            """
+            latency_key = f"{model_group}_map"
+
+            current_date = datetime.now().strftime("%Y-%m-%d")
+            current_hour = datetime.now().strftime("%H")
+            current_minute = datetime.now().strftime("%M")
+            precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+            response_ms: timedelta = end_time - start_time
+
+            final_value = response_ms
+            total_tokens = 0
+
+            if isinstance(response_obj, ModelResponse):
+                completion_tokens = response_obj.usage.completion_tokens
+                total_tokens = response_obj.usage.total_tokens
+                final_value = float(completion_tokens / response_ms.total_seconds())
 
             # ------------
             # Update usage
             # ------------
-            ## Latency
             request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
-            if id in request_count_dict and isinstance(request_count_dict[id], list):
-                request_count_dict[id] = request_count_dict[id].append(response_ms)
-            else:
-                request_count_dict[id] = [response_ms]
 
-            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds)  # reset map within window
+            if id not in request_count_dict:
+                request_count_dict[id] = {}
+
+            ## Latency
+            request_count_dict[id].setdefault("latency", []).append(final_value)
+
+            if precise_minute not in request_count_dict[id]:
+                request_count_dict[id][precise_minute] = {}
+
+            ## TPM
+            request_count_dict[id][precise_minute]["tpm"] = (
+                request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
+            )
+
+            ## RPM
+            request_count_dict[id][precise_minute]["rpm"] = (
+                request_count_dict[id][precise_minute].get("rpm", 0) + 1
+            )
+
+            self.router_cache.set_cache(
+                key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
+            )  # reset map within window
 
             ### TESTING ###
             if self.test_flag:
@@ -80,25 +140,64 @@ class LowestLatencyLoggingHandler(CustomLogger):
             if model_group is None or id is None:
                 return
 
-            response_ms = end_time - start_time
-
             # ------------
             # Setup values
             # ------------
-            latency_key = f"{model_group}_latency_map"
+            """
+            {
+                {model_group}_map: {
+                    id: {
+                        "latency": [..]
+                        f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
+                    }
+                }
+            }
+            """
+            latency_key = f"{model_group}_map"
+
+            current_date = datetime.now().strftime("%Y-%m-%d")
+            current_hour = datetime.now().strftime("%H")
+            current_minute = datetime.now().strftime("%M")
+            precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+            response_ms: timedelta = end_time - start_time
+
+            final_value = response_ms
+            total_tokens = 0
+
+            if isinstance(response_obj, ModelResponse):
+                completion_tokens = response_obj.usage.completion_tokens
+                total_tokens = response_obj.usage.total_tokens
+                final_value = float(completion_tokens / response_ms.total_seconds())
 
             # ------------
             # Update usage
             # ------------
-            ## Latency
             request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
-            if id in request_count_dict and isinstance(request_count_dict[id], list):
-                request_count_dict[id] = request_count_dict[id] + [response_ms]
-            else:
-                request_count_dict[id] = [response_ms]
 
-            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds)  # reset map within window
+            if id not in request_count_dict:
+                request_count_dict[id] = {}
+
+            ## Latency
+            request_count_dict[id].setdefault("latency", []).append(final_value)
+
+            if precise_minute not in request_count_dict[id]:
+                request_count_dict[id][precise_minute] = {}
+
+            ## TPM
+            request_count_dict[id][precise_minute]["tpm"] = (
+                request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
+            )
+
+            ## RPM
+            request_count_dict[id][precise_minute]["rpm"] = (
+                request_count_dict[id][precise_minute].get("rpm", 0) + 1
+            )
+
+            self.router_cache.set_cache(
+                key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
+            )  # reset map within window
 
             ### TESTING ###
             if self.test_flag:
@@ -107,12 +206,18 @@ class LowestLatencyLoggingHandler(CustomLogger):
             traceback.print_exc()
             pass
 
-    def get_available_deployments(self, model_group: str, healthy_deployments: list):
+    def get_available_deployments(
+        self,
+        model_group: str,
+        healthy_deployments: list,
+        messages: Optional[List[Dict[str, str]]] = None,
+        input: Optional[Union[str, List]] = None,
+    ):
         """
         Returns a deployment with the lowest latency
         """
         # get list of potential deployments
-        latency_key = f"{model_group}_latency_map"
+        latency_key = f"{model_group}_map"
 
         request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
@@ -120,6 +225,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
         # Find lowest used model
         # ----------------------
         lowest_latency = float("inf")
+
+        current_date = datetime.now().strftime("%Y-%m-%d")
+        current_hour = datetime.now().strftime("%H")
+        current_minute = datetime.now().strftime("%M")
+        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
         deployment = None
 
         if request_count_dict is None:  # base case
@@ -129,9 +240,17 @@ class LowestLatencyLoggingHandler(CustomLogger):
         for d in healthy_deployments:
             ## if healthy deployment not yet used
             if d["model_info"]["id"] not in all_deployments:
-                all_deployments[d["model_info"]["id"]] = [0]
+                all_deployments[d["model_info"]["id"]] = {
+                    "latency": [0],
+                    precise_minute: {"tpm": 0, "rpm": 0},
+                }
 
-        for item, item_latency in all_deployments.items():
+        try:
+            input_tokens = token_counter(messages=messages, text=input)
+        except:
+            input_tokens = 0
+
+        for item, item_map in all_deployments.items():
             ## get the item from model list
             _deployment = None
             for m in healthy_deployments:
@@ -140,18 +259,38 @@ class LowestLatencyLoggingHandler(CustomLogger):
             if _deployment is None:
                 continue  # skip to next one
 
-            # get average latency
-            total = 0.0
+            _deployment_tpm = (
+                _deployment.get("tpm", None)
+                or _deployment.get("litellm_params", {}).get("tpm", None)
+                or _deployment.get("model_info", {}).get("tpm", None)
+                or float("inf")
+            )
+
+            _deployment_rpm = (
+                _deployment.get("rpm", None)
+                or _deployment.get("litellm_params", {}).get("rpm", None)
+                or _deployment.get("model_info", {}).get("rpm", None)
+                or float("inf")
+            )
+
+            item_latency = item_map.get("latency", [])
+            item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
+            item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)
+
+            # get average latency
+            total: float = 0.0
             for _call_latency in item_latency:
-                if isinstance(_call_latency, timedelta):
-                    total += float(_call_latency.total_seconds())
-                elif isinstance(_call_latency, float):
+                if isinstance(_call_latency, float):
                     total += _call_latency
-            item_latency = total/len(item_latency)
+            item_latency = total / len(item_latency)
             if item_latency == 0:
                 deployment = _deployment
                 break
+            elif (
+                item_tpm + input_tokens > _deployment_tpm
+                or item_rpm + 1 > _deployment_rpm
+            ):  # if user passed in tpm / rpm in the model_list
+                continue
             elif item_latency < lowest_latency:
                 lowest_latency = item_latency
                 deployment = _deployment
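As a rough sketch of how the updated strategy can be exercised directly (the import path and the empty cache are assumptions; the tests in the next file do essentially the same thing): passing `messages` lets token_counter estimate input tokens, so the per-minute TPM/RPM counters recorded in log_success_event can veto a deployment that would exceed its configured limit.

from litellm.caching import DualCache
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler  # assumed module path

# Hypothetical direct use of the handler; deployment ids and limits are illustrative.
cache = DualCache()
handler = LowestLatencyLoggingHandler(
    router_cache=cache, model_list=[], routing_args={"ttl": 120}
)
picked = handler.get_available_deployments(
    model_group="gpt-3.5-turbo",
    healthy_deployments=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "azure/chatgpt-v-2"},
            "model_info": {"id": "1234", "rpm": 10},
        }
    ],
    messages=[{"role": "user", "content": "hey"}],  # used for the input-token estimate
)
print(picked["model_info"]["id"])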

litellm/tests/test_lowest_latency_routing.py

@@ -48,13 +48,55 @@ def test_latency_updated():
         start_time=start_time,
         end_time=end_time,
     )
-    latency_key = f"{model_group}_latency_map"
-    assert end_time - start_time == test_cache.get_cache(key=latency_key)[deployment_id][0]
+    latency_key = f"{model_group}_map"
+    assert (
+        end_time - start_time
+        == test_cache.get_cache(key=latency_key)[deployment_id]["latency"][0]
+    )
 
 
+# test_tpm_rpm_updated()
+
+
+def test_latency_updated_custom_ttl():
+    """
+    Invalidate the cached request.
+
+    Test that the cache is empty
+    """
+    test_cache = DualCache()
+    model_list = []
+    cache_time = 3
+    lowest_latency_logger = LowestLatencyLoggingHandler(
+        router_cache=test_cache, model_list=model_list, routing_args={"ttl": cache_time}
+    )
+    model_group = "gpt-3.5-turbo"
+    deployment_id = "1234"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 50}}
+    time.sleep(5)
+    end_time = time.time()
+    lowest_latency_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+    latency_key = f"{model_group}_map"
+    assert isinstance(test_cache.get_cache(key=latency_key), dict)
+    time.sleep(cache_time)
+
+    assert test_cache.get_cache(key=latency_key) is None
 
 
 def test_get_available_deployments():
     test_cache = DualCache()
     model_list = [
@@ -133,6 +175,90 @@ def test_get_available_deployments():
 # test_get_available_deployments()
 
 
+def test_get_available_endpoints_tpm_rpm_check():
+    """
+    Pass in list of 2 valid models
+
+    Update cache with 1 model clearly being at tpm/rpm limit
+
+    assert that only the valid model is returned
+    """
+    test_cache = DualCache()
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "azure/chatgpt-v-2"},
+            "model_info": {"id": "1234", "rpm": 10},
+        },
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "azure/chatgpt-v-2"},
+            "model_info": {"id": "5678", "rpm": 3},
+        },
+    ]
+    lowest_latency_logger = LowestLatencyLoggingHandler(
+        router_cache=test_cache, model_list=model_list
+    )
+    model_group = "gpt-3.5-turbo"
+
+    ## DEPLOYMENT 1 ##
+    deployment_id = "1234"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    for _ in range(3):
+        start_time = time.time()
+        response_obj = {"usage": {"total_tokens": 50}}
+        time.sleep(0.05)
+        end_time = time.time()
+        lowest_latency_logger.log_success_event(
+            response_obj=response_obj,
+            kwargs=kwargs,
+            start_time=start_time,
+            end_time=end_time,
+        )
+
+    ## DEPLOYMENT 2 ##
+    deployment_id = "5678"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    for _ in range(3):
+        start_time = time.time()
+        response_obj = {"usage": {"total_tokens": 20}}
+        time.sleep(2)
+        end_time = time.time()
+        lowest_latency_logger.log_success_event(
+            response_obj=response_obj,
+            kwargs=kwargs,
+            start_time=start_time,
+            end_time=end_time,
+        )
+
+    ## CHECK WHAT'S SELECTED ##
+    print(
+        lowest_latency_logger.get_available_deployments(
+            model_group=model_group, healthy_deployments=model_list
+        )
+    )
+    assert (
+        lowest_latency_logger.get_available_deployments(
+            model_group=model_group, healthy_deployments=model_list
+        )["model_info"]["id"]
+        == "1234"
+    )
+
+
 def test_router_get_available_deployments():
     """
     Test if routers 'get_available_deployments' returns the fastest deployment
@@ -213,9 +339,6 @@ def test_router_get_available_deployments():
     assert router.get_available_deployment(model="azure-model")["model_info"]["id"] == 2
 
 
-# test_get_available_deployments()
-
-
 # test_router_get_available_deployments()