forked from phoenix/litellm-mirror
test(test_least_busy_router.py): add better testing for least busy routing
This commit is contained in:
parent
678bbfa9be
commit
54d7bc2cc3
2 changed files with 79 additions and 4 deletions
|
@ -16,6 +16,10 @@ from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
|
||||||
|
|
||||||
class LeastBusyLoggingHandler(CustomLogger):
|
class LeastBusyLoggingHandler(CustomLogger):
|
||||||
|
test_flag: bool = False
|
||||||
|
logged_success: int = 0
|
||||||
|
logged_failure: int = 0
|
||||||
|
|
||||||
def __init__(self, router_cache: DualCache):
    """Keep a handle on the shared router cache and start with an empty
    deployment-name -> deployment-id mapping."""
    self.mapping_deployment_to_id: dict = {}
    self.router_cache = router_cache
|
@ -50,6 +54,63 @@ class LeastBusyLoggingHandler(CustomLogger):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def log_success_event(self, kwargs, response_obj, start_time, end_time):
    """Decrement the in-flight request count for the deployment whose call
    just completed successfully.

    Expects ``kwargs["litellm_params"]["metadata"]["model_group"]`` and
    ``kwargs["litellm_params"]["model_info"]["id"]``; if either is missing
    the event is ignored. Errors are swallowed on purpose — logging must
    never break the request path.
    """
    try:
        metadata = kwargs["litellm_params"].get("metadata")
        if metadata is None:
            return
        model_group = metadata.get("model_group", None)
        deployment_id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
        if model_group is None or deployment_id is None:
            return

        request_count_api_key = f"{model_group}_request_count"
        # decrement count in cache
        request_count_dict = (
            self.router_cache.get_cache(key=request_count_api_key) or {}
        )
        # Default to 0: the previous `.get(id) - 1` raised TypeError
        # (None - 1) for an unseen deployment id, and the blanket except
        # below silently dropped the whole update.
        request_count_dict[deployment_id] = request_count_dict.get(deployment_id, 0) - 1
        self.router_cache.set_cache(
            key=request_count_api_key, value=request_count_dict
        )

        ### TESTING ###
        if self.test_flag:
            self.logged_success += 1
    except Exception:
        # Best-effort observability hook: never raise into the caller.
        pass
|
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """Decrement the in-flight request count for the deployment whose call
    just failed (a failed call is no longer "busy").

    Expects ``kwargs["litellm_params"]["metadata"]["model_group"]`` and
    ``kwargs["litellm_params"]["model_info"]["id"]``; if either is missing
    the event is ignored. Errors are swallowed on purpose — logging must
    never break the request path.
    """
    try:
        metadata = kwargs["litellm_params"].get("metadata")
        if metadata is None:
            return
        model_group = metadata.get("model_group", None)
        deployment_id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
        if model_group is None or deployment_id is None:
            return

        request_count_api_key = f"{model_group}_request_count"
        # decrement count in cache
        request_count_dict = (
            self.router_cache.get_cache(key=request_count_api_key) or {}
        )
        # Default to 0: the previous `.get(id) - 1` raised TypeError
        # (None - 1) for an unseen deployment id, and the blanket except
        # below silently dropped the whole update.
        request_count_dict[deployment_id] = request_count_dict.get(deployment_id, 0) - 1
        self.router_cache.set_cache(
            key=request_count_api_key, value=request_count_dict
        )

        ### TESTING ###
        if self.test_flag:
            self.logged_failure += 1
    except Exception:
        # Best-effort observability hook: never raise into the caller.
        pass
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
if kwargs["litellm_params"].get("metadata") is None:
|
if kwargs["litellm_params"].get("metadata") is None:
|
||||||
|
@ -72,6 +133,10 @@ class LeastBusyLoggingHandler(CustomLogger):
|
||||||
self.router_cache.set_cache(
|
self.router_cache.set_cache(
|
||||||
key=request_count_api_key, value=request_count_dict
|
key=request_count_api_key, value=request_count_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### TESTING ###
|
||||||
|
if self.test_flag:
|
||||||
|
self.logged_success += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -96,6 +161,10 @@ class LeastBusyLoggingHandler(CustomLogger):
|
||||||
self.router_cache.set_cache(
|
self.router_cache.set_cache(
|
||||||
key=request_count_api_key, value=request_count_dict
|
key=request_count_api_key, value=request_count_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### TESTING ###
|
||||||
|
if self.test_flag:
|
||||||
|
self.logged_failure += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -110,6 +110,8 @@ def test_router_get_available_deployments():
|
||||||
num_retries=3,
|
num_retries=3,
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
router.leastbusy_logger.test_flag = True
|
||||||
|
|
||||||
model_group = "azure-model"
|
model_group = "azure-model"
|
||||||
deployment = "azure/chatgpt-v-2"
|
deployment = "azure/chatgpt-v-2"
|
||||||
request_count_dict = {1: 10, 2: 54, 3: 100}
|
request_count_dict = {1: 10, 2: 54, 3: 100}
|
||||||
|
@ -120,15 +122,19 @@ def test_router_get_available_deployments():
|
||||||
print(f"deployment: {deployment}")
|
print(f"deployment: {deployment}")
|
||||||
assert deployment["model_info"]["id"] == 1
|
assert deployment["model_info"]["id"] == 1
|
||||||
|
|
||||||
## run router completion - assert that the least-busy deployment was incremented
|
## run router completion - assert completion event, no change in 'busy'ness once calls are complete
|
||||||
|
|
||||||
router.completion(
|
router.completion(
|
||||||
model=model_group,
|
model=model_group,
|
||||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
)
|
)
|
||||||
|
|
||||||
least_busy_dict = router.cache.get_cache(key=cache_key)
|
return_dict = router.cache.get_cache(key=cache_key)
|
||||||
assert least_busy_dict[1] == 11
|
|
||||||
|
assert router.leastbusy_logger.logged_success == 1
|
||||||
|
assert return_dict[1] == 10
|
||||||
|
assert return_dict[2] == 54
|
||||||
|
assert return_dict[3] == 100
|
||||||
|
|
||||||
|
|
||||||
# test_router_get_available_deployments()
|
test_router_get_available_deployments()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue