fix(lowest_latency.py): add back tpm/rpm checks, configurable time window

This commit is contained in:
Krrish Dholakia 2024-01-10 20:52:01 +05:30
parent ec3e597b61
commit f288b12411
3 changed files with 312 additions and 51 deletions

View file

@ -105,6 +105,7 @@ class Router:
"usage-based-routing",
"latency-based-routing",
] = "simple-shuffle",
routing_strategy_args: dict = {}, # just for latency-based routing
) -> None:
self.set_verbose = set_verbose
self.deployment_names: List = (
@ -217,7 +218,9 @@ class Router:
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
elif routing_strategy == "latency-based-routing":
self.lowestlatency_logger = LowestLatencyLoggingHandler(
router_cache=self.cache, model_list=self.model_list
router_cache=self.cache,
model_list=self.model_list,
routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowestlatency_logger) # type: ignore
@ -1427,9 +1430,8 @@ class Router:
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
limits=httpx.Limits(
max_connections=1000,
max_keepalive_connections=100
)
max_connections=1000, max_keepalive_connections=100
),
), # type: ignore
)
self.cache.set_cache(
@ -1449,9 +1451,8 @@ class Router:
http_client=httpx.Client(
transport=CustomHTTPTransport(),
limits=httpx.Limits(
max_connections=1000,
max_keepalive_connections=100
)
max_connections=1000, max_keepalive_connections=100
),
), # type: ignore
)
self.cache.set_cache(
@ -1471,10 +1472,9 @@ class Router:
max_retries=max_retries,
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000,
max_keepalive_connections=100
max_connections=1000, max_keepalive_connections=100
)
)
),
)
self.cache.set_cache(
key=cache_key,
@ -1492,10 +1492,9 @@ class Router:
max_retries=max_retries,
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000,
max_keepalive_connections=100
max_connections=1000, max_keepalive_connections=100
)
)
),
)
self.cache.set_cache(
key=cache_key,

View file

@ -1,25 +1,46 @@
#### What this does ####
# picks based on response time (for streaming, this is time to first token)
from pydantic import BaseModel, Extra, Field, root_validator
import dotenv, os, requests, random
from typing import Optional
from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm import ModelResponse
from litellm import token_counter
class LiteLLMBase(BaseModel):
"""
Implements default functions, all pydantic objects should have.
"""
def json(self, **kwargs):
try:
return self.model_dump() # noqa
except:
# if using pydantic v1
return self.dict()
class RoutingArgs(LiteLLMBase):
ttl: int = 1 * 60 * 60 # 1 hour
class LowestLatencyLoggingHandler(CustomLogger):
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
def __init__(self, router_cache: DualCache, model_list: list):
def __init__(
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
):
self.router_cache = router_cache
self.model_list = model_list
self.routing_args = RoutingArgs(**routing_args)
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
@ -37,25 +58,64 @@ class LowestLatencyLoggingHandler(CustomLogger):
if model_group is None or id is None:
return
response_ms = end_time - start_time
# ------------
# Setup values
# ------------
latency_key = f"{model_group}_latency_map"
"""
{
{model_group}_map: {
id: {
"latency": [..]
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
}
}
}
"""
latency_key = f"{model_group}_map"
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
response_ms: timedelta = end_time - start_time
final_value = response_ms
total_tokens = 0
if isinstance(response_obj, ModelResponse):
completion_tokens = response_obj.usage.completion_tokens
total_tokens = response_obj.usage.total_tokens
final_value = float(completion_tokens / response_ms.total_seconds())
# ------------
# Update usage
# ------------
## Latency
request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
if id in request_count_dict and isinstance(request_count_dict[id], list):
request_count_dict[id] = request_count_dict[id].append(response_ms)
else:
request_count_dict[id] = [response_ms]
self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds) # reset map within window
if id not in request_count_dict:
request_count_dict[id] = {}
## Latency
request_count_dict[id].setdefault("latency", []).append(final_value)
if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}
## TPM
request_count_dict[id][precise_minute]["tpm"] = (
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
)
## RPM
request_count_dict[id][precise_minute]["rpm"] = (
request_count_dict[id][precise_minute].get("rpm", 0) + 1
)
self.router_cache.set_cache(
key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
) # reset map within window
### TESTING ###
if self.test_flag:
@ -80,25 +140,64 @@ class LowestLatencyLoggingHandler(CustomLogger):
if model_group is None or id is None:
return
response_ms = end_time - start_time
# ------------
# Setup values
# ------------
latency_key = f"{model_group}_latency_map"
"""
{
{model_group}_map: {
id: {
"latency": [..]
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
}
}
}
"""
latency_key = f"{model_group}_map"
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
response_ms: timedelta = end_time - start_time
final_value = response_ms
total_tokens = 0
if isinstance(response_obj, ModelResponse):
completion_tokens = response_obj.usage.completion_tokens
total_tokens = response_obj.usage.total_tokens
final_value = float(completion_tokens / response_ms.total_seconds())
# ------------
# Update usage
# ------------
## Latency
request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
if id in request_count_dict and isinstance(request_count_dict[id], list):
request_count_dict[id] = request_count_dict[id] + [response_ms]
else:
request_count_dict[id] = [response_ms]
self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds) # reset map within window
if id not in request_count_dict:
request_count_dict[id] = {}
## Latency
request_count_dict[id].setdefault("latency", []).append(final_value)
if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}
## TPM
request_count_dict[id][precise_minute]["tpm"] = (
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
)
## RPM
request_count_dict[id][precise_minute]["rpm"] = (
request_count_dict[id][precise_minute].get("rpm", 0) + 1
)
self.router_cache.set_cache(
key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
) # reset map within window
### TESTING ###
if self.test_flag:
@ -107,12 +206,18 @@ class LowestLatencyLoggingHandler(CustomLogger):
traceback.print_exc()
pass
def get_available_deployments(self, model_group: str, healthy_deployments: list):
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
):
"""
Returns a deployment with the lowest latency
"""
# get list of potential deployments
latency_key = f"{model_group}_latency_map"
latency_key = f"{model_group}_map"
request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
@ -120,6 +225,12 @@ class LowestLatencyLoggingHandler(CustomLogger):
# Find lowest used model
# ----------------------
lowest_latency = float("inf")
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
deployment = None
if request_count_dict is None: # base case
@ -129,9 +240,17 @@ class LowestLatencyLoggingHandler(CustomLogger):
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in all_deployments:
all_deployments[d["model_info"]["id"]] = [0]
all_deployments[d["model_info"]["id"]] = {
"latency": [0],
precise_minute: {"tpm": 0, "rpm": 0},
}
for item, item_latency in all_deployments.items():
try:
input_tokens = token_counter(messages=messages, text=input)
except:
input_tokens = 0
for item, item_map in all_deployments.items():
## get the item from model list
_deployment = None
for m in healthy_deployments:
@ -140,18 +259,38 @@ class LowestLatencyLoggingHandler(CustomLogger):
if _deployment is None:
continue # skip to next one
# get average latency
total = 0.0
_deployment_tpm = (
_deployment.get("tpm", None)
or _deployment.get("litellm_params", {}).get("tpm", None)
or _deployment.get("model_info", {}).get("tpm", None)
or float("inf")
)
_deployment_rpm = (
_deployment.get("rpm", None)
or _deployment.get("litellm_params", {}).get("rpm", None)
or _deployment.get("model_info", {}).get("rpm", None)
or float("inf")
)
item_latency = item_map.get("latency", [])
item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)
# get average latency
total: float = 0.0
for _call_latency in item_latency:
if isinstance(_call_latency, timedelta):
total += float(_call_latency.total_seconds())
elif isinstance(_call_latency, float):
if isinstance(_call_latency, float):
total += _call_latency
item_latency = total/len(item_latency)
item_latency = total / len(item_latency)
if item_latency == 0:
deployment = _deployment
break
elif (
item_tpm + input_tokens > _deployment_tpm
or item_rpm + 1 > _deployment_rpm
): # if user passed in tpm / rpm in the model_list
continue
elif item_latency < lowest_latency:
lowest_latency = item_latency
deployment = _deployment

View file

@ -48,13 +48,55 @@ def test_latency_updated():
start_time=start_time,
end_time=end_time,
)
latency_key = f"{model_group}_latency_map"
assert end_time - start_time == test_cache.get_cache(key=latency_key)[deployment_id][0]
latency_key = f"{model_group}_map"
assert (
end_time - start_time
== test_cache.get_cache(key=latency_key)[deployment_id]["latency"][0]
)
# test_tpm_rpm_updated()
def test_latency_updated_custom_ttl():
"""
Invalidate the cached request.
Test that the cache is empty
"""
test_cache = DualCache()
model_list = []
cache_time = 3
lowest_latency_logger = LowestLatencyLoggingHandler(
router_cache=test_cache, model_list=model_list, routing_args={"ttl": cache_time}
)
model_group = "gpt-3.5-turbo"
deployment_id = "1234"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": deployment_id},
}
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
time.sleep(5)
end_time = time.time()
lowest_latency_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
latency_key = f"{model_group}_map"
assert isinstance(test_cache.get_cache(key=latency_key), dict)
time.sleep(cache_time)
assert test_cache.get_cache(key=latency_key) is None
def test_get_available_deployments():
test_cache = DualCache()
model_list = [
@ -133,6 +175,90 @@ def test_get_available_deployments():
# test_get_available_deployments()
def test_get_available_endpoints_tpm_rpm_check():
"""
Pass in list of 2 valid models
Update cache with 1 model clearly being at tpm/rpm limit
assert that only the valid model is returned
"""
test_cache = DualCache()
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "azure/chatgpt-v-2"},
"model_info": {"id": "1234", "rpm": 10},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "azure/chatgpt-v-2"},
"model_info": {"id": "5678", "rpm": 3},
},
]
lowest_latency_logger = LowestLatencyLoggingHandler(
router_cache=test_cache, model_list=model_list
)
model_group = "gpt-3.5-turbo"
## DEPLOYMENT 1 ##
deployment_id = "1234"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": deployment_id},
}
}
for _ in range(3):
start_time = time.time()
response_obj = {"usage": {"total_tokens": 50}}
time.sleep(0.05)
end_time = time.time()
lowest_latency_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## DEPLOYMENT 2 ##
deployment_id = "5678"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": deployment_id},
}
}
for _ in range(3):
start_time = time.time()
response_obj = {"usage": {"total_tokens": 20}}
time.sleep(2)
end_time = time.time()
lowest_latency_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
## CHECK WHAT'S SELECTED ##
print(
lowest_latency_logger.get_available_deployments(
model_group=model_group, healthy_deployments=model_list
)
)
assert (
lowest_latency_logger.get_available_deployments(
model_group=model_group, healthy_deployments=model_list
)["model_info"]["id"]
== "1234"
)
def test_router_get_available_deployments():
"""
Test if routers 'get_available_deployments' returns the fastest deployment
@ -213,9 +339,6 @@ def test_router_get_available_deployments():
assert router.get_available_deployment(model="azure-model")["model_info"]["id"] == 2
# test_get_available_deployments()
# test_router_get_available_deployments()