fix(router.py): allow user to control the latency routing time window

Krrish Dholakia 2024-01-10 08:48:22 +05:30
parent 2b3fc15fa9
commit fe632c08a4
3 changed files with 59 additions and 7 deletions
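In practice, the new parameter is passed straight through Router's constructor, so callers can shrink or extend the window that latency stats live in the cache. A minimal usage sketch (the deployment entry below is illustrative, not part of this commit):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",  # hypothetical deployment for illustration
            "litellm_params": {"model": "azure/chatgpt-v-2"},
        }
    ],
    routing_strategy="latency-based-routing",
    routing_strategy_args={"ttl": 300},  # keep latency stats for 5 minutes instead of the 1-hour default
)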

litellm/router.py

@@ -105,6 +105,7 @@ class Router:
             "usage-based-routing",
             "latency-based-routing",
         ] = "simple-shuffle",
+        routing_strategy_args: dict = {},  # just for latency-based routing
     ) -> None:
         self.set_verbose = set_verbose
         self.deployment_names: List = (
@@ -217,7 +218,7 @@ class Router:
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
         elif routing_strategy == "latency-based-routing":
             self.lowestlatency_logger = LowestLatencyLoggingHandler(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache, model_list=self.model_list, routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowestlatency_logger)  # type: ignore

litellm/router_strategy/lowest_latency.py

@@ -1,6 +1,6 @@
 #### What this does ####
 # picks based on response time (for streaming, this is time to first token)
+from pydantic import BaseModel, Extra, Field, root_validator
 import dotenv, os, requests, random
 from typing import Optional
 from datetime import datetime, timedelta
@@ -10,16 +10,30 @@ import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60 * 60  # 1 hour
+
+
 class LowestLatencyLoggingHandler(CustomLogger):
     test_flag: bool = False
     logged_success: int = 0
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
@@ -55,7 +69,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
             else:
                 request_count_dict[id] = [response_ms]
 
-            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds)  # reset map within window
+            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl)  # reset map within window
 
             ### TESTING ###
             if self.test_flag:
@@ -98,7 +112,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
             else:
                 request_count_dict[id] = [response_ms]
 
-            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds)  # reset map within window
+            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl)  # reset map within window
 
             ### TESTING ###
             if self.test_flag:
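Because routing_args arrives as a plain dict and is validated through the new RoutingArgs pydantic model, a missing key simply falls back to the model's default. A standalone sketch of that parsing behavior, assuming the classes above are importable:

args = RoutingArgs(**{"ttl": 300})  # user-supplied window, in seconds
assert args.ttl == 300

args = RoutingArgs(**{})  # empty dict (the __init__ default) keeps the 1-hour window
assert args.ttl == 1 * 60 * 60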

litellm/tests/test_lowest_latency_routing.py

@@ -51,9 +51,46 @@ def test_latency_updated():
     latency_key = f"{model_group}_latency_map"
     assert end_time - start_time == test_cache.get_cache(key=latency_key)[deployment_id][0]
 
 # test_tpm_rpm_updated()
+
+
+def test_latency_updated_custom_ttl():
+    """
+    Invalidate the cached request.
+    Test that the cache is empty
+    """
+    test_cache = DualCache()
+    model_list = []
+    cache_time = 3
+    lowest_latency_logger = LowestLatencyLoggingHandler(
+        router_cache=test_cache, model_list=model_list, routing_args={"ttl": cache_time}
+    )
+    model_group = "gpt-3.5-turbo"
+    deployment_id = "1234"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 50}}
+    time.sleep(5)
+    end_time = time.time()
+    lowest_latency_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+    latency_key = f"{model_group}_latency_map"
+    assert isinstance(test_cache.get_cache(key=latency_key), dict)
+    time.sleep(cache_time)
+    assert test_cache.get_cache(key=latency_key) is None
+
+
 def test_get_available_deployments():
     test_cache = DualCache()
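The custom-TTL test hinges on DualCache honoring the per-key ttl passed to set_cache: once the window elapses, the latency map disappears and the next request starts a fresh one. The pattern it asserts reduces to roughly this sketch (using only calls that appear in the diff):

from litellm.caching import DualCache
import time

cache = DualCache()
cache.set_cache(key="gpt-3.5-turbo_latency_map", value={"1234": [1.2]}, ttl=3)
assert isinstance(cache.get_cache(key="gpt-3.5-turbo_latency_map"), dict)
time.sleep(3)
assert cache.get_cache(key="gpt-3.5-turbo_latency_map") is None  # expired with the window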