forked from phoenix/litellm-mirror
fix(router.py): allow user to control the latency routing time window
parent 2b3fc15fa9
commit fe632c08a4
3 changed files with 59 additions and 7 deletions
@@ -105,6 +105,7 @@ class Router:
             "usage-based-routing",
             "latency-based-routing",
         ] = "simple-shuffle",
+        routing_strategy_args: dict = {}, # just for latency-based routing
     ) -> None:
         self.set_verbose = set_verbose
         self.deployment_names: List = (
@@ -217,7 +218,7 @@ class Router:
                 litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
         elif routing_strategy == "latency-based-routing":
             self.lowestlatency_logger = LowestLatencyLoggingHandler(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache, model_list=self.model_list, routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowestlatency_logger) # type: ignore
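With these two changes a caller can tune the latency time window when constructing the Router. A minimal usage sketch of the new parameter (the model_list entry and API key below are placeholders, not part of this commit):

from litellm import Router

# placeholder deployment config -- values are illustrative only
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
    }
]

router = Router(
    model_list=model_list,
    routing_strategy="latency-based-routing",
    # keep per-deployment latency samples for 10 seconds instead of the 1 hour default
    routing_strategy_args={"ttl": 10},
)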
@@ -1,6 +1,6 @@
 #### What this does ####
 # picks based on response time (for streaming, this is time to first token)

+from pydantic import BaseModel, Extra, Field, root_validator
 import dotenv, os, requests, random
 from typing import Optional
 from datetime import datetime, timedelta
@@ -10,16 +10,30 @@ import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger

+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60 * 60 # 1 hour
+
 class LowestLatencyLoggingHandler(CustomLogger):
     test_flag: bool = False
     logged_success: int = 0
     logged_failure: int = 0
     default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour

-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict={}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)

     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
@@ -55,7 +69,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
             else:
                 request_count_dict[id] = [response_ms]

-            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds) # reset map within window
+            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl) # reset map within window

             ### TESTING ###
             if self.test_flag:
@@ -98,7 +112,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
             else:
                 request_count_dict[id] = [response_ms]

-            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds) # reset map within window
+            self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl) # reset map within window

             ### TESTING ###
             if self.test_flag:
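The RoutingArgs model turns the raw routing_strategy_args dict into a typed setting while keeping the old 1 hour default. A standalone sketch of the intended behavior (pydantic defaulting only, not part of the diff):

from pydantic import BaseModel

class RoutingArgs(BaseModel):
    ttl: int = 1 * 60 * 60  # 1 hour, mirrors default_cache_time_seconds

# no override -> the old one-hour window
assert RoutingArgs(**{}).ttl == 3600
# user-supplied window, e.g. routing_strategy_args={"ttl": 10}
assert RoutingArgs(**{"ttl": 10}).ttl == 10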
@@ -51,9 +51,46 @@ def test_latency_updated():
     latency_key = f"{model_group}_latency_map"
     assert end_time - start_time == test_cache.get_cache(key=latency_key)[deployment_id][0]


 # test_tpm_rpm_updated()

+def test_latency_updated_custom_ttl():
+    """
+    Invalidate the cached request.
+
+    Test that the cache is empty
+    """
+    test_cache = DualCache()
+    model_list = []
+    cache_time = 3
+    lowest_latency_logger = LowestLatencyLoggingHandler(
+        router_cache=test_cache, model_list=model_list, routing_args={"ttl": cache_time}
+    )
+    model_group = "gpt-3.5-turbo"
+    deployment_id = "1234"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 50}}
+    time.sleep(5)
+    end_time = time.time()
+    lowest_latency_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+    latency_key = f"{model_group}_latency_map"
+    assert isinstance(test_cache.get_cache(key=latency_key), dict)
+    time.sleep(cache_time)
+    assert test_cache.get_cache(key=latency_key) is None
+
+
 def test_get_available_deployments():
     test_cache = DualCache()
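The new test relies on DualCache honoring the ttl passed to set_cache, so the latency map disappears once the window elapses. A minimal sketch of that expectation (key and value mirror the test but are illustrative):

import time
from litellm.caching import DualCache

cache = DualCache()
# the entry should expire roughly 3 seconds after being written
cache.set_cache(key="gpt-3.5-turbo_latency_map", value={"1234": [5.0]}, ttl=3)
assert cache.get_cache(key="gpt-3.5-turbo_latency_map") is not None
time.sleep(4)  # a bit past the 3s ttl
assert cache.get_cache(key="gpt-3.5-turbo_latency_map") is None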