forked from phoenix/litellm-mirror

fix(router.py): allow user to control the latency routing time window

parent 2b3fc15fa9
commit fe632c08a4

3 changed files with 59 additions and 7 deletions
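For orientation, a minimal caller-side sketch of the new knob. The `model_list` entry and the 300-second value are illustrative and not taken from this commit; only `routing_strategy="latency-based-routing"` and `routing_strategy_args={"ttl": ...}` come from the diff below.

```python
# Hypothetical usage sketch: pick latency-based routing and shrink the window
# that latency samples are kept for (default is 1 hour per this commit).
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",  # illustrative deployment entry
            "litellm_params": {"model": "azure/chatgpt-v-2"},
        }
    ],
    routing_strategy="latency-based-routing",
    routing_strategy_args={"ttl": 300},  # keep latency samples for 5 minutes
)
```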
@@ -105,6 +105,7 @@ class Router:
             "usage-based-routing",
             "latency-based-routing",
         ] = "simple-shuffle",
+        routing_strategy_args: dict = {}, # just for latency-based routing
     ) -> None:
         self.set_verbose = set_verbose
         self.deployment_names: List = (
@@ -217,7 +218,7 @@ class Router:
                 litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
         elif routing_strategy == "latency-based-routing":
             self.lowestlatency_logger = LowestLatencyLoggingHandler(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache, model_list=self.model_list, routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowestlatency_logger) # type: ignore

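The Router forwards the caller-supplied dict untouched; validation happens inside the latency logger. A standalone sketch of the equivalent wiring follows. The import path is assumed from litellm's layout and is not shown in this diff; only the constructor signature comes from the hunks above and below.

```python
# Constructing the handler directly, the same way Router now does internally.
# Module path is an assumption; constructor arguments mirror the diff.
from litellm.caching import DualCache
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler

logger = LowestLatencyLoggingHandler(
    router_cache=DualCache(),
    model_list=[],
    routing_args={"ttl": 120},  # illustrative 2-minute window
)
```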
@@ -1,6 +1,6 @@
 #### What this does ####
 # picks based on response time (for streaming, this is time to first token)
+from pydantic import BaseModel, Extra, Field, root_validator
 import dotenv, os, requests, random
 from typing import Optional
 from datetime import datetime, timedelta
@@ -10,16 +10,30 @@ import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger


+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60 * 60 # 1 hour
+
+
 class LowestLatencyLoggingHandler(CustomLogger):
     test_flag: bool = False
     logged_success: int = 0
     logged_failure: int = 0
-    default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour

-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict={}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)

     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
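`RoutingArgs` is what turns the raw `routing_strategy_args` dict into a typed setting. A minimal, self-contained illustration of that step, using a stand-in class declared locally (the real one, inheriting `LiteLLMBase`, is in the hunk above):

```python
from pydantic import BaseModel

class RoutingArgs(BaseModel):  # stand-in for the class added in this commit
    ttl: int = 1 * 60 * 60  # 1 hour default latency window

print(RoutingArgs().ttl)              # 3600 -> default window
print(RoutingArgs(**{"ttl": 3}).ttl)  # 3    -> caller-supplied window wins
```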
@@ -55,7 +69,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
             else:
                 request_count_dict[id] = [response_ms]

-            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds) # reset map within window
+            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl) # reset map within window

             ### TESTING ###
             if self.test_flag:
@@ -98,7 +112,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
             else:
                 request_count_dict[id] = [response_ms]

-            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds) # reset map within window
+            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl) # reset map within window

             ### TESTING ###
             if self.test_flag:

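Both call sites now use the caller-controlled TTL, so the whole latency map expires after `routing_args.ttl` seconds. A small sketch of the expiry behavior the new test below relies on; the cache key format comes from the code, while the deployment id and latency value are illustrative:

```python
import time
from litellm.caching import DualCache

cache = DualCache()
cache.set_cache(key="gpt-3.5-turbo_latency_map", value={"1234": [0.21]}, ttl=3)
print(cache.get_cache(key="gpt-3.5-turbo_latency_map"))  # {'1234': [0.21]}
time.sleep(3)
print(cache.get_cache(key="gpt-3.5-turbo_latency_map"))  # None -> window elapsed, map reset
```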
@@ -51,9 +51,46 @@ def test_latency_updated():
     latency_key = f"{model_group}_latency_map"
     assert end_time - start_time == test_cache.get_cache(key=latency_key)[deployment_id][0]


 # test_tpm_rpm_updated()


+def test_latency_updated_custom_ttl():
+    """
+    Invalidate the cached request.
+
+    Test that the cache is empty
+    """
+    test_cache = DualCache()
+    model_list = []
+    cache_time = 3
+    lowest_latency_logger = LowestLatencyLoggingHandler(
+        router_cache=test_cache, model_list=model_list, routing_args={"ttl": cache_time}
+    )
+    model_group = "gpt-3.5-turbo"
+    deployment_id = "1234"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 50}}
+    time.sleep(5)
+    end_time = time.time()
+    lowest_latency_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+    latency_key = f"{model_group}_latency_map"
+    assert isinstance(test_cache.get_cache(key=latency_key), dict)
+    time.sleep(cache_time)
+    assert test_cache.get_cache(key=latency_key) is None
+
+
 def test_get_available_deployments():
     test_cache = DualCache()
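To run only the new test locally, something like the following should work. The test module's path is not shown in this diff, so pytest selects the test by keyword from the repository root:

```python
# Illustrative runner: select the new test by name instead of shelling out.
import pytest

pytest.main(["-k", "test_latency_updated_custom_ttl", "-x", "-s"])
```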