fix(router.py): fix least-busy routing

This commit is contained in:
Krrish Dholakia 2023-12-08 20:29:37 -08:00
parent d9b115b8fb
commit 4bf875d3ed
8 changed files with 292 additions and 31 deletions

View file

@ -9,14 +9,14 @@
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback
import random, threading, time, traceback, uuid
import litellm, openai
from litellm.caching import RedisCache, InMemoryCache, DualCache
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
class Router:
"""
Example usage:
@ -57,6 +57,7 @@ class Router:
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
num_retries: int = 0
tenacity = None
leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
def __init__(self,
model_list: Optional[list] = None,
@ -98,7 +99,7 @@ class Router:
self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model
self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model
self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call)
# make Router.chat.completions.create compatible for openai.chat.completions.create
self.chat = litellm.Chat(params=default_litellm_params)
@ -107,10 +108,6 @@ class Router:
self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0)
### HEALTH CHECK THREAD ###
if self.routing_strategy == "least-busy":
self._start_health_check_thread()
### CACHING ###
cache_type = "local" # default to an in-memory cache
redis_cache = None
@ -137,6 +134,16 @@ class Router:
litellm.cache = litellm.Cache(type=cache_type, **cache_config)
self.cache_responses = cache_responses
self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
### ROUTING SETUP ###
if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
## add callback
if isinstance(litellm.input_callback, list):
litellm.input_callback.append(self.leastbusy_logger) # type: ignore
else:
litellm.input_callback = [self.leastbusy_logger] # type: ignore
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.leastbusy_logger) # type: ignore
## USAGE TRACKING ##
if isinstance(litellm.success_callback, list):
litellm.success_callback.append(self.deployment_callback)
@ -664,6 +671,7 @@ class Router:
return kwargs
except Exception as e:
raise e
def _set_cooldown_deployments(self,
deployment: str):
"""
@ -873,6 +881,10 @@ class Router:
for model in self.model_list:
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
#### MODEL ID INIT ########
model_info = model.get("model_info", {})
model_info["id"] = model_info.get("id", str(uuid.uuid4()))
model["model_info"] = model_info
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
if custom_llm_provider is None:
@ -1119,8 +1131,8 @@ class Router:
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead
healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
self.print_verbose(f"initial list of deployments: {healthy_deployments}")
deployments_to_remove = []
cooldown_deployments = self._get_cooldown_deployments()
@ -1140,13 +1152,24 @@ class Router:
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
if self.routing_strategy == "least-busy":
if len(self.healthy_deployments) > 0:
for item in self.healthy_deployments:
if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
return item[0]
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
deployments = self.leastbusy_logger.get_available_deployments(model_group=model)
# pick least busy deployment
min_traffic = float('inf')
min_deployment = None
for k, v in deployments.items():
if v < min_traffic:
min_deployment = k
############## No Available Deployments passed, we do a random pick #################
if min_deployment is None:
min_deployment = random.choice(healthy_deployments)
############## Available Deployments passed, we find the relevant item #################
else:
raise ValueError("No models available.")
for m in healthy_deployments:
if m["model_info"]["id"] == min_deployment:
return m
min_deployment = random.choice(healthy_deployments)
return min_deployment
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if we can do a RPM/TPM based weighted pick #################