diff --git a/litellm/router.py b/litellm/router.py
index 311afeb44..99bdf8d46 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -105,6 +105,7 @@ class Router:
             "usage-based-routing",
             "latency-based-routing",
         ] = "simple-shuffle",
+        routing_strategy_args: dict = {},  # just for latency-based routing
     ) -> None:
         self.set_verbose = set_verbose
         self.deployment_names: List = (
@@ -217,7 +218,7 @@ class Router:
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
         elif routing_strategy == "latency-based-routing":
             self.lowestlatency_logger = LowestLatencyLoggingHandler(
-                router_cache=self.cache, model_list=self.model_list
+                router_cache=self.cache, model_list=self.model_list, routing_args=routing_strategy_args
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowestlatency_logger)  # type: ignore
diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py
index 43e28a8b3..53d1bf3a4 100644
--- a/litellm/router_strategy/lowest_latency.py
+++ b/litellm/router_strategy/lowest_latency.py
@@ -1,6 +1,6 @@
 #### What this does ####
 #   picks based on response time (for streaming, this is time to first token)
-
+from pydantic import BaseModel, Extra, Field, root_validator
 import dotenv, os, requests, random
 from typing import Optional
 from datetime import datetime, timedelta
@@ -10,16 +10,30 @@ import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
+class RoutingArgs(LiteLLMBase):
+    ttl: int = 1 * 60 * 60  # 1 hour
+
 
 class LowestLatencyLoggingHandler(CustomLogger):
     test_flag: bool = False
     logged_success: int = 0
     logged_failure: int = 0
-    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
 
-    def __init__(self, router_cache: DualCache, model_list: list):
+    def __init__(self, router_cache: DualCache, model_list: list, routing_args: dict = {}):
         self.router_cache = router_cache
         self.model_list = model_list
+        self.routing_args = RoutingArgs(**routing_args)
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
@@ -55,7 +69,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
                 else:
                     request_count_dict[id] = [response_ms]
 
-                self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds)  # reset map within window
+                self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl)  # reset map within window
 
                 ### TESTING ###
                 if self.test_flag:
@@ -98,7 +112,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
                 else:
                     request_count_dict[id] = [response_ms]
 
-                self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds)  # reset map within window
+                self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl)  # reset map within window
 
                 ### TESTING ###
                 if self.test_flag:
diff --git a/litellm/tests/test_lowest_latency_routing.py b/litellm/tests/test_lowest_latency_routing.py
index 6805bfa58..09c29d6a4 100644
--- a/litellm/tests/test_lowest_latency_routing.py
+++ b/litellm/tests/test_lowest_latency_routing.py
@@ -51,9 +51,46 @@ def test_latency_updated():
     latency_key = f"{model_group}_latency_map"
     assert end_time - start_time == test_cache.get_cache(key=latency_key)[deployment_id][0]
 
-# test_tpm_rpm_updated()
+
+def test_latency_updated_custom_ttl():
+    """
+    Invalidate the cached request.
+
+    Test that the cache is empty
+    """
+    test_cache = DualCache()
+    model_list = []
+    cache_time = 3
+    lowest_latency_logger = LowestLatencyLoggingHandler(
+        router_cache=test_cache, model_list=model_list, routing_args={"ttl": cache_time}
+    )
+    model_group = "gpt-3.5-turbo"
+    deployment_id = "1234"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 50}}
+    time.sleep(5)
+    end_time = time.time()
+    lowest_latency_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+    latency_key = f"{model_group}_latency_map"
+    assert isinstance(test_cache.get_cache(key=latency_key), dict)
+    time.sleep(cache_time)
+    assert test_cache.get_cache(key=latency_key) is None
+
+
 def test_get_available_deployments():
     test_cache = DualCache()
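
Usage note (appended for context, not part of the diff): a minimal sketch of how a caller might pass the new routing_strategy_args parameter when constructing the Router. The model list entries, credential placeholders, and the 10-minute TTL below are illustrative values, not taken from this PR.

    from litellm import Router

    # illustrative deployment config; replace with real keys and endpoints
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",  # model group used for routing
            "litellm_params": {
                "model": "azure/chatgpt-v-2",                        # example deployment
                "api_key": "my-azure-api-key",                       # placeholder credential
                "api_base": "https://my-endpoint.openai.azure.com",  # placeholder endpoint
            },
        },
    ]

    router = Router(
        model_list=model_list,
        routing_strategy="latency-based-routing",
        # keep per-deployment latency stats for 10 minutes instead of the 1-hour default
        routing_strategy_args={"ttl": 10 * 60},
    )

Omitting routing_strategy_args leaves RoutingArgs at its default ttl of 3600 seconds, matching the previous hard-coded default_cache_time_seconds behavior.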