fix(router.py): re-initialize the routing strategy when it is updated

This commit is contained in:
Krrish Dholakia 2024-05-01 09:50:45 -07:00
parent b3a788142b
commit 1ad67a0d75

View file

@ -290,6 +290,21 @@ class Router:
} }
""" """
### ROUTING SETUP ### ### ROUTING SETUP ###
self.routing_strategy_init(
routing_strategy=routing_strategy,
routing_strategy_args=routing_strategy_args,
)
## COOLDOWNS ##
if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure)
else:
litellm.failure_callback = [self.deployment_callback_on_failure]
print( # noqa
f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
) # noqa
self.routing_strategy_args = routing_strategy_args
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
if routing_strategy == "least-busy": if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler( self.leastbusy_logger = LeastBusyLoggingHandler(
router_cache=self.cache, model_list=self.model_list router_cache=self.cache, model_list=self.model_list
@ -321,15 +336,6 @@ class Router:
) )
if isinstance(litellm.callbacks, list): if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowestlatency_logger) # type: ignore litellm.callbacks.append(self.lowestlatency_logger) # type: ignore
## COOLDOWNS ##
if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure)
else:
litellm.failure_callback = [self.deployment_callback_on_failure]
print( # noqa
f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
) # noqa
self.routing_strategy_args = routing_strategy_args
def print_deployment(self, deployment: dict): def print_deployment(self, deployment: dict):
""" """
@ -2659,6 +2665,13 @@ class Router:
_casted_value = int(kwargs[var]) _casted_value = int(kwargs[var])
setattr(self, var, _casted_value) setattr(self, var, _casted_value)
else: else:
if var == "routing_strategy":
self.routing_strategy_init(
routing_strategy=kwargs[var],
routing_strategy_args=kwargs.get(
"routing_strategy_args", {}
),
)
setattr(self, var, kwargs[var]) setattr(self, var, kwargs[var])
else: else:
verbose_router_logger.debug("Setting {} is not allowed".format(var)) verbose_router_logger.debug("Setting {} is not allowed".format(var))