diff --git a/litellm/router_strategy/least_busy.py b/litellm/router_strategy/least_busy.py
index 8b608d463..64a0aa99a 100644
--- a/litellm/router_strategy/least_busy.py
+++ b/litellm/router_strategy/least_busy.py
@@ -16,6 +16,10 @@ from litellm.integrations.custom_logger import CustomLogger
 
 
 class LeastBusyLoggingHandler(CustomLogger):
+    test_flag: bool = False
+    logged_success: int = 0
+    logged_failure: int = 0
+
     def __init__(self, router_cache: DualCache):
         self.router_cache = router_cache
         self.mapping_deployment_to_id: dict = {}
@@ -50,6 +54,63 @@ class LeastBusyLoggingHandler(CustomLogger):
         except Exception as e:
             pass
 
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+
+                request_count_api_key = f"{model_group}_request_count"
+                # decrement count in cache
+                request_count_dict = (
+                    self.router_cache.get_cache(key=request_count_api_key) or {}
+                )
+                request_count_dict[id] = request_count_dict.get(id) - 1
+                self.router_cache.set_cache(
+                    key=request_count_api_key, value=request_count_dict
+                )
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_success += 1
+        except Exception as e:
+            pass
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+
+                request_count_api_key = f"{model_group}_request_count"
+                # decrement count in cache
+                request_count_dict = (
+                    self.router_cache.get_cache(key=request_count_api_key) or {}
+                )
+                request_count_dict[id] = request_count_dict.get(id) - 1
+                self.router_cache.set_cache(
+                    key=request_count_api_key, value=request_count_dict
+                )
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_failure += 1
+        except Exception as e:
+            pass
+
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             if kwargs["litellm_params"].get("metadata") is None:
@@ -72,6 +133,10 @@ class LeastBusyLoggingHandler(CustomLogger):
                 self.router_cache.set_cache(
                     key=request_count_api_key, value=request_count_dict
                 )
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_success += 1
         except Exception as e:
             pass
 
@@ -96,6 +161,10 @@ class LeastBusyLoggingHandler(CustomLogger):
                 self.router_cache.set_cache(
                     key=request_count_api_key, value=request_count_dict
                 )
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_failure += 1
         except Exception as e:
             pass
 
diff --git a/litellm/tests/test_least_busy_routing.py b/litellm/tests/test_least_busy_routing.py
index 8c818ff4c..bd0855ed7 100644
--- a/litellm/tests/test_least_busy_routing.py
+++ b/litellm/tests/test_least_busy_routing.py
@@ -110,6 +110,8 @@ def test_router_get_available_deployments():
         num_retries=3,
     )  # type: ignore
 
+    router.leastbusy_logger.test_flag = True
+
     model_group = "azure-model"
     deployment = "azure/chatgpt-v-2"
     request_count_dict = {1: 10, 2: 54, 3: 100}
@@ -120,15 +122,19 @@
     print(f"deployment: {deployment}")
     assert deployment["model_info"]["id"] == 1
 
-    ## run router completion - assert that the least-busy deployment was incremented
+    ## run router completion - assert completion event, no change in 'busy'ness once calls are complete
     router.completion(
         model=model_group,
         messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
-    least_busy_dict = router.cache.get_cache(key=cache_key)
-    assert least_busy_dict[1] == 11
+    return_dict = router.cache.get_cache(key=cache_key)
+
+    assert router.leastbusy_logger.logged_success == 1
+    assert return_dict[1] == 10
+    assert return_dict[2] == 54
+    assert return_dict[3] == 100
 
 
-# test_router_get_available_deployments()
+test_router_get_available_deployments()
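Below is a minimal standalone sketch (not part of the diff) of how the new synchronous `log_success_event` callback could be exercised directly against a `DualCache`, mirroring what the updated unit test checks: once a call completes, the handler decrements that deployment's entry in the `{model_group}_request_count` dict instead of leaving it incremented. The import paths and the no-argument `DualCache()` construction are assumptions about the surrounding litellm layout; everything else follows the interfaces shown in the diff above.

```python
# Hypothetical standalone check of the new sync callback; relies only on the
# constructor, kwargs shape, and cache-key format visible in the diff above.
from litellm.caching import DualCache  # assumed import path
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler

router_cache = DualCache()
handler = LeastBusyLoggingHandler(router_cache=router_cache)
handler.test_flag = True  # enable the test-only counters added in this PR

model_group = "azure-model"
deployment_id = 1
request_count_key = f"{model_group}_request_count"

# Seed the per-deployment request counts, as the unit test does via the router cache.
router_cache.set_cache(key=request_count_key, value={deployment_id: 10})

# Simulate a completed call: the handler decrements this deployment's count.
kwargs = {
    "litellm_params": {
        "metadata": {"model_group": model_group},
        "model_info": {"id": deployment_id},
    }
}
handler.log_success_event(kwargs, response_obj=None, start_time=None, end_time=None)

assert router_cache.get_cache(key=request_count_key)[deployment_id] == 9
assert handler.logged_success == 1
```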