diff --git a/litellm/router.py b/litellm/router.py index 99e2435ac..3f2bef476 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -2958,6 +2958,7 @@ class Router: if ( self.routing_strategy != "usage-based-routing-v2" and self.routing_strategy != "simple-shuffle" + and self.routing_strategy != "cost-based-routing" ): # prevent regressions for other routing strategies, that don't have async get available deployments implemented. return self.get_available_deployment( model=model, @@ -3014,6 +3015,16 @@ class Router: messages=messages, input=input, ) + if ( + self.routing_strategy == "cost-based-routing" + and self.lowestcost_logger is not None + ): + deployment = await self.lowestcost_logger.async_get_available_deployments( + model_group=model, + healthy_deployments=healthy_deployments, + messages=messages, + input=input, + ) elif self.routing_strategy == "simple-shuffle": # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm ############## Check if we can do a RPM/TPM based weighted pick ################# @@ -3184,15 +3195,6 @@ class Router: messages=messages, input=input, ) - elif ( - self.routing_strategy == "cost-based-routing" - and self.lowestcost_logger is not None - ): - deployment = self.lowestcost_logger.get_available_deployments( - model_group=model, - healthy_deployments=healthy_deployments, - request_kwargs=request_kwargs, - ) if deployment is None: verbose_router_logger.info( f"get_available_deployment for model: {model}, No deployment available" diff --git a/litellm/router_strategy/lowest_cost.py b/litellm/router_strategy/lowest_cost.py index 44b49378d..2d010fb4f 100644 --- a/litellm/router_strategy/lowest_cost.py +++ b/litellm/router_strategy/lowest_cost.py @@ -40,7 +40,7 @@ class LowestCostLoggingHandler(CustomLogger): self.router_cache = router_cache self.model_list = model_list - def log_success_event(self, kwargs, response_obj, start_time, end_time): + async def log_success_event(self, kwargs, response_obj, start_time, end_time): try: """ Update usage on success @@ -90,7 +90,11 @@ class LowestCostLoggingHandler(CustomLogger): # Update usage # ------------ - request_count_dict = self.router_cache.get_cache(key=cost_key) or {} + request_count_dict = ( + await self.router_cache.async_get_cache(key=cost_key) or {} + ) + + # check local result first if id not in request_count_dict: request_count_dict[id] = {} @@ -111,7 +115,9 @@ class LowestCostLoggingHandler(CustomLogger): request_count_dict[id][precise_minute].get("rpm", 0) + 1 ) - self.router_cache.set_cache(key=cost_key, value=request_count_dict) + await self.router_cache.async_set_cache( + key=cost_key, value=request_count_dict + ) ### TESTING ### if self.test_flag: @@ -172,7 +178,9 @@ class LowestCostLoggingHandler(CustomLogger): # Update usage # ------------ - request_count_dict = self.router_cache.get_cache(key=cost_key) or {} + request_count_dict = ( + await self.router_cache.async_get_cache(key=cost_key) or {} + ) if id not in request_count_dict: request_count_dict[id] = {} @@ -189,7 +197,7 @@ class LowestCostLoggingHandler(CustomLogger): request_count_dict[id][precise_minute].get("rpm", 0) + 1 ) - self.router_cache.set_cache( + await self.router_cache.async_set_cache( key=cost_key, value=request_count_dict ) # reset map within window @@ -200,7 +208,7 @@ class LowestCostLoggingHandler(CustomLogger): traceback.print_exc() pass - def get_available_deployments( + async def async_get_available_deployments( self, model_group: str, healthy_deployments: list, @@ -213,7 +221,7 @@ class LowestCostLoggingHandler(CustomLogger): """ cost_key = f"{model_group}_map" - request_count_dict = self.router_cache.get_cache(key=cost_key) or {} + request_count_dict = await self.router_cache.async_get_cache(key=cost_key) or {} # ----------------------- # Find lowest used model