feat - make lowest_cost pure async

This commit is contained in:
Ishaan Jaff 2024-05-07 13:51:50 -07:00
parent 8e5437c8e9
commit 6983e7a84f
2 changed files with 26 additions and 16 deletions

View file

@ -2958,6 +2958,7 @@ class Router:
if ( if (
self.routing_strategy != "usage-based-routing-v2" self.routing_strategy != "usage-based-routing-v2"
and self.routing_strategy != "simple-shuffle" and self.routing_strategy != "simple-shuffle"
and self.routing_strategy != "cost-based-routing"
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented. ): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
return self.get_available_deployment( return self.get_available_deployment(
model=model, model=model,
@ -3014,6 +3015,16 @@ class Router:
messages=messages, messages=messages,
input=input, input=input,
) )
if (
self.routing_strategy == "cost-based-routing"
and self.lowestcost_logger is not None
):
deployment = await self.lowestcost_logger.async_get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments,
messages=messages,
input=input,
)
elif self.routing_strategy == "simple-shuffle": elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if we can do a RPM/TPM based weighted pick ################# ############## Check if we can do a RPM/TPM based weighted pick #################
@ -3184,15 +3195,6 @@ class Router:
messages=messages, messages=messages,
input=input, input=input,
) )
elif (
self.routing_strategy == "cost-based-routing"
and self.lowestcost_logger is not None
):
deployment = self.lowestcost_logger.get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
)
if deployment is None: if deployment is None:
verbose_router_logger.info( verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available" f"get_available_deployment for model: {model}, No deployment available"

View file

@ -40,7 +40,7 @@ class LowestCostLoggingHandler(CustomLogger):
self.router_cache = router_cache self.router_cache = router_cache
self.model_list = model_list self.model_list = model_list
def log_success_event(self, kwargs, response_obj, start_time, end_time): async def log_success_event(self, kwargs, response_obj, start_time, end_time):
try: try:
""" """
Update usage on success Update usage on success
@ -90,7 +90,11 @@ class LowestCostLoggingHandler(CustomLogger):
# Update usage # Update usage
# ------------ # ------------
request_count_dict = self.router_cache.get_cache(key=cost_key) or {} request_count_dict = (
await self.router_cache.async_get_cache(key=cost_key) or {}
)
# check local result first
if id not in request_count_dict: if id not in request_count_dict:
request_count_dict[id] = {} request_count_dict[id] = {}
@ -111,7 +115,9 @@ class LowestCostLoggingHandler(CustomLogger):
request_count_dict[id][precise_minute].get("rpm", 0) + 1 request_count_dict[id][precise_minute].get("rpm", 0) + 1
) )
self.router_cache.set_cache(key=cost_key, value=request_count_dict) await self.router_cache.async_set_cache(
key=cost_key, value=request_count_dict
)
### TESTING ### ### TESTING ###
if self.test_flag: if self.test_flag:
@ -172,7 +178,9 @@ class LowestCostLoggingHandler(CustomLogger):
# Update usage # Update usage
# ------------ # ------------
request_count_dict = self.router_cache.get_cache(key=cost_key) or {} request_count_dict = (
await self.router_cache.async_get_cache(key=cost_key) or {}
)
if id not in request_count_dict: if id not in request_count_dict:
request_count_dict[id] = {} request_count_dict[id] = {}
@ -189,7 +197,7 @@ class LowestCostLoggingHandler(CustomLogger):
request_count_dict[id][precise_minute].get("rpm", 0) + 1 request_count_dict[id][precise_minute].get("rpm", 0) + 1
) )
self.router_cache.set_cache( await self.router_cache.async_set_cache(
key=cost_key, value=request_count_dict key=cost_key, value=request_count_dict
) # reset map within window ) # reset map within window
@ -200,7 +208,7 @@ class LowestCostLoggingHandler(CustomLogger):
traceback.print_exc() traceback.print_exc()
pass pass
def get_available_deployments( async def async_get_available_deployments(
self, self,
model_group: str, model_group: str,
healthy_deployments: list, healthy_deployments: list,
@ -213,7 +221,7 @@ class LowestCostLoggingHandler(CustomLogger):
""" """
cost_key = f"{model_group}_map" cost_key = f"{model_group}_map"
request_count_dict = self.router_cache.get_cache(key=cost_key) or {} request_count_dict = await self.router_cache.async_get_cache(key=cost_key) or {}
# ----------------------- # -----------------------
# Find lowest used model # Find lowest used model