test(test_least_busy_router.py): add better testing for least busy routing

This commit is contained in:
Krrish Dholakia 2023-12-29 17:16:00 +05:30
parent 678bbfa9be
commit 54d7bc2cc3
2 changed files with 79 additions and 4 deletions

View file

@ -16,6 +16,10 @@ from litellm.integrations.custom_logger import CustomLogger
class LeastBusyLoggingHandler(CustomLogger):
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
def __init__(self, router_cache: DualCache):
self.router_cache = router_cache
self.mapping_deployment_to_id: dict = {}
@ -50,6 +54,63 @@ class LeastBusyLoggingHandler(CustomLogger):
except Exception as e:
pass
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[id] = request_count_dict.get(id) - 1
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
pass
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[id] = request_count_dict.get(id) - 1
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
### TESTING ###
if self.test_flag:
self.logged_failure += 1
except Exception as e:
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs["litellm_params"].get("metadata") is None:
@ -72,6 +133,10 @@ class LeastBusyLoggingHandler(CustomLogger):
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
pass
@ -96,6 +161,10 @@ class LeastBusyLoggingHandler(CustomLogger):
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
### TESTING ###
if self.test_flag:
self.logged_failure += 1
except Exception as e:
pass

View file

@ -110,6 +110,8 @@ def test_router_get_available_deployments():
num_retries=3,
) # type: ignore
router.leastbusy_logger.test_flag = True
model_group = "azure-model"
deployment = "azure/chatgpt-v-2"
request_count_dict = {1: 10, 2: 54, 3: 100}
@ -120,15 +122,19 @@ def test_router_get_available_deployments():
print(f"deployment: {deployment}")
assert deployment["model_info"]["id"] == 1
## run router completion - assert that the least-busy deployment was incremented
## run router completion - assert completion event, no change in 'busy'ness once calls are complete
router.completion(
model=model_group,
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
least_busy_dict = router.cache.get_cache(key=cache_key)
assert least_busy_dict[1] == 11
return_dict = router.cache.get_cache(key=cache_key)
assert router.leastbusy_logger.logged_success == 1
assert return_dict[1] == 10
assert return_dict[2] == 54
assert return_dict[3] == 100
# test_router_get_available_deployments()
test_router_get_available_deployments()