mirror of https://github.com/BerriAI/litellm.git
fix(router.py): fix least-busy routing
This commit is contained in:
parent d9b115b8fb
commit 4bf875d3ed
8 changed files with 292 additions and 31 deletions
@@ -9,14 +9,14 @@
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal, Any
-import random, threading, time, traceback
+import random, threading, time, traceback, uuid
 import litellm, openai
 from litellm.caching import RedisCache, InMemoryCache, DualCache
 import logging, asyncio
 import inspect, concurrent
 from openai import AsyncOpenAI
 from collections import defaultdict

+from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 class Router:
     """
     Example usage:
@@ -57,6 +57,7 @@ class Router:
     default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
     num_retries: int = 0
     tenacity = None
+    leastbusy_logger: Optional[LeastBusyLoggingHandler] = None

     def __init__(self,
                  model_list: Optional[list] = None,
@@ -98,7 +99,7 @@ class Router:
         self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model
         self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model
         self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call)


         # make Router.chat.completions.create compatible for openai.chat.completions.create
         self.chat = litellm.Chat(params=default_litellm_params)
@@ -107,10 +108,6 @@ class Router:
         self.default_litellm_params.setdefault("timeout", timeout)
         self.default_litellm_params.setdefault("max_retries", 0)

-
-        ### HEALTH CHECK THREAD ###
-        if self.routing_strategy == "least-busy":
-            self._start_health_check_thread()
         ### CACHING ###
         cache_type = "local" # default to an in-memory cache
         redis_cache = None
@@ -137,6 +134,16 @@ class Router:
             litellm.cache = litellm.Cache(type=cache_type, **cache_config)
             self.cache_responses = cache_responses
         self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
+        ### ROUTING SETUP ###
+        if routing_strategy == "least-busy":
+            self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
+            ## add callback
+            if isinstance(litellm.input_callback, list):
+                litellm.input_callback.append(self.leastbusy_logger) # type: ignore
+            else:
+                litellm.input_callback = [self.leastbusy_logger] # type: ignore
+            if isinstance(litellm.callbacks, list):
+                litellm.callbacks.append(self.leastbusy_logger) # type: ignore
         ## USAGE TRACKING ##
         if isinstance(litellm.success_callback, list):
             litellm.success_callback.append(self.deployment_callback)
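With this change, selecting the "least-busy" strategy at construction time registers a LeastBusyLoggingHandler on litellm's callback lists instead of starting a health-check thread. A minimal usage sketch; the deployment names, keys, and endpoints below are placeholders, not taken from this commit:

from litellm import Router

# hypothetical two-deployment group; credentials and endpoints are placeholders
model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # model group callers request
        "litellm_params": {"model": "azure/chatgpt-v-2", "api_key": "AZURE_KEY", "api_base": "https://example-endpoint"},
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "OPENAI_KEY"},
    },
]

# with routing_strategy="least-busy", the constructor shown in the hunk above
# attaches the LeastBusyLoggingHandler to litellm.input_callback / litellm.callbacks
router = Router(model_list=model_list, routing_strategy="least-busy")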
@@ -664,6 +671,7 @@ class Router:
             return kwargs
         except Exception as e:
             raise e

     def _set_cooldown_deployments(self,
                                   deployment: str):
         """
@@ -873,6 +881,10 @@ class Router:
         for model in self.model_list:
             litellm_params = model.get("litellm_params", {})
             model_name = litellm_params.get("model")
+            #### MODEL ID INIT ########
+            model_info = model.get("model_info", {})
+            model_info["id"] = model_info.get("id", str(uuid.uuid4()))
+            model["model_info"] = model_info
             #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
             custom_llm_provider = litellm_params.get("custom_llm_provider")
             if custom_llm_provider is None:
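The new MODEL ID INIT block gives every entry in model_list a stable model_info["id"], defaulting to a fresh uuid4 string, which the least-busy pick further down matches deployments against. A rough standalone sketch of the same defaulting; the sample entry is illustrative:

import uuid

model = {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/chatgpt-v-2"}}

# reuse an id if one was supplied in model_info, otherwise mint a uuid4 string
model_info = model.get("model_info", {})
model_info["id"] = model_info.get("id", str(uuid.uuid4()))
model["model_info"] = model_info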
@@ -1119,8 +1131,8 @@ class Router:
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
         if len(healthy_deployments) == 0:
             # check if the user sent in a deployment name instead
             healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]

         self.print_verbose(f"initial list of deployments: {healthy_deployments}")
         deployments_to_remove = []
         cooldown_deployments = self._get_cooldown_deployments()
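Before a deployment is chosen, entries that are currently cooling down are dropped from healthy_deployments. The exact removal loop is not shown in this hunk; a small sketch assuming _get_cooldown_deployments returns a list of deployment ids:

cooldown_deployments = ["deployment-2"]  # hypothetical ids on cooldown

healthy_deployments = [
    {"model_name": "gpt-3.5-turbo", "model_info": {"id": "deployment-1"}},
    {"model_name": "gpt-3.5-turbo", "model_info": {"id": "deployment-2"}},
]

# keep only deployments that are not on cooldown
healthy_deployments = [
    m for m in healthy_deployments if m["model_info"]["id"] not in cooldown_deployments
]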
@@ -1140,13 +1152,24 @@ class Router:
             model = litellm.model_alias_map[
                 model
             ] # update the model to the actual value if an alias has been passed in
-        if self.routing_strategy == "least-busy":
-            if len(self.healthy_deployments) > 0:
-                for item in self.healthy_deployments:
-                    if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
-                        return item[0]
-            else:
-                raise ValueError("No models available.")
+        if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
+            deployments = self.leastbusy_logger.get_available_deployments(model_group=model)
+            # pick least busy deployment
+            min_traffic = float('inf')
+            min_deployment = None
+            for k, v in deployments.items():
+                if v < min_traffic:
+                    min_deployment = k
+            ############## No Available Deployments passed, we do a random pick #################
+            if min_deployment is None:
+                min_deployment = random.choice(healthy_deployments)
+            ############## Available Deployments passed, we find the relevant item #################
+            else:
+                for m in healthy_deployments:
+                    if m["model_info"]["id"] == min_deployment:
+                        return m
+                min_deployment = random.choice(healthy_deployments)
+            return min_deployment
         elif self.routing_strategy == "simple-shuffle":
             # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
             ############## Check if we can do a RPM/TPM based weighted pick #################
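The rewritten least-busy branch asks the LeastBusyLoggingHandler for per-deployment traffic counts and returns the deployment with the lowest count, falling back to a random choice when no counts are available. A self-contained sketch of that selection, assuming get_available_deployments returns a mapping of deployment id to in-flight request count; unlike the hunk above, the sketch also advances the running minimum on each comparison:

import random

def pick_least_busy(deployments: dict, healthy_deployments: list):
    """Return the healthy deployment with the fewest in-flight requests."""
    min_traffic = float("inf")
    min_deployment = None
    for deployment_id, traffic in deployments.items():
        if traffic < min_traffic:
            min_traffic = traffic        # keep the running minimum current
            min_deployment = deployment_id
    if min_deployment is None:
        # no usage data yet -> random pick among healthy deployments
        return random.choice(healthy_deployments)
    for m in healthy_deployments:
        if m["model_info"]["id"] == min_deployment:
            return m
    return random.choice(healthy_deployments)

# example: deployment-1 has 2 in-flight requests, deployment-2 has 5
deployments = {"deployment-1": 2, "deployment-2": 5}
healthy = [
    {"model_name": "gpt-3.5-turbo", "model_info": {"id": "deployment-1"}},
    {"model_name": "gpt-3.5-turbo", "model_info": {"id": "deployment-2"}},
]
print(pick_least_busy(deployments, healthy)["model_info"]["id"])  # -> deployment-1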