Mirror of https://github.com/BerriAI/litellm.git

commit 4bf875d3ed (parent d9b115b8fb)
fix(router.py): fix least-busy routing

8 changed files with 292 additions and 31 deletions
@@ -9,6 +9,7 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
@@ -61,8 +61,6 @@ class InMemoryCache(BaseCache):
                cached_response = json.loads(original_cached_response)
            except:
                cached_response = original_cached_response
            if isinstance(cached_response, dict):
                cached_response['cache'] = True # set cache-hit flag to True
            return cached_response
        return None
@@ -110,8 +108,6 @@ class RedisCache(BaseCache):
                cached_response = json.loads(cached_response)  # Convert string to dictionary
            except:
                cached_response = ast.literal_eval(cached_response)
            if isinstance(cached_response, dict):
                cached_response['cache'] = True # set cache-hit flag to True
            return cached_response
        except Exception as e:
            # NON blocking - notify users Redis is throwing an exception
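For reference, the two cache hunks above share one deserialization pattern: try JSON first, fall back to a Python-literal parse, and tag dict results with a cache-hit flag. A minimal standalone sketch of that pattern (the helper name _deserialize_cached_response is illustrative, not part of litellm):

import ast
import json

def _deserialize_cached_response(original_cached_response):
    # hypothetical helper mirroring the hunks above: prefer JSON, fall back to a
    # Python-literal parse, and finally return the raw value unchanged
    try:
        cached_response = json.loads(original_cached_response)
    except Exception:
        try:
            cached_response = ast.literal_eval(original_cached_response)
        except Exception:
            cached_response = original_cached_response
    if isinstance(cached_response, dict):
        cached_response["cache"] = True  # mark the result as a cache hit
    return cached_response

print(_deserialize_cached_response('{"choices": []}'))  # {'choices': [], 'cache': True}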
@@ -28,6 +28,9 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        pass

    async def async_log_pre_api_call(self, model, messages, kwargs):
        pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        pass
@@ -51,6 +54,22 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
            traceback.print_exc()
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
        try:
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["log_event_type"] = "pre_api_call"
            await callback_func(
                kwargs,
            )
            print_verbose(
                f"Custom Logger - model call details: {kwargs}"
            )
        except:
            traceback.print_exc()
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")


    def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
        # Method definition
        try:
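These hooks are consumed by subclassing CustomLogger and registering an instance on litellm's callback lists, which is what the Router does for its least-busy handler further down. A minimal sketch of such a subclass (the class name and the in-flight counter are illustrative only, not part of this commit):

from litellm.integrations.custom_logger import CustomLogger

class MyRequestTracker(CustomLogger):
    # hypothetical subclass: count requests as they are about to be sent
    def __init__(self):
        self.in_flight = 0

    def log_pre_api_call(self, model, messages, kwargs):
        self.in_flight += 1
        print(f"pre-call to {model}, in flight: {self.in_flight}")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.in_flight -= 1

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self.in_flight -= 1

An instance appended to litellm.callbacks (or litellm.input_callback) receives these events for every request.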
@@ -9,14 +9,14 @@

from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback
import random, threading, time, traceback, uuid
import litellm, openai
from litellm.caching import RedisCache, InMemoryCache, DualCache
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict

from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
class Router:
    """
    Example usage:
@@ -57,6 +57,7 @@ class Router:
    default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
    num_retries: int = 0
    tenacity = None
    leastbusy_logger: Optional[LeastBusyLoggingHandler] = None

    def __init__(self,
                 model_list: Optional[list] = None,
@@ -107,10 +108,6 @@ class Router:
        self.default_litellm_params.setdefault("timeout", timeout)
        self.default_litellm_params.setdefault("max_retries", 0)


        ### HEALTH CHECK THREAD ###
        if self.routing_strategy == "least-busy":
            self._start_health_check_thread()
        ### CACHING ###
        cache_type = "local" # default to an in-memory cache
        redis_cache = None
@@ -137,6 +134,16 @@ class Router:
            litellm.cache = litellm.Cache(type=cache_type, **cache_config)
            self.cache_responses = cache_responses
        self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
        ### ROUTING SETUP ###
        if routing_strategy == "least-busy":
            self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
            ## add callback
            if isinstance(litellm.input_callback, list):
                litellm.input_callback.append(self.leastbusy_logger) # type: ignore
            else:
                litellm.input_callback = [self.leastbusy_logger] # type: ignore
            if isinstance(litellm.callbacks, list):
                litellm.callbacks.append(self.leastbusy_logger) # type: ignore
        ## USAGE TRACKING ##
        if isinstance(litellm.success_callback, list):
            litellm.success_callback.append(self.deployment_callback)
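A minimal usage sketch of the wiring above, adapted from the commented-out test added later in this commit (model names, keys, and endpoints are placeholders, and the public get_available_deployment call is shown only to illustrate where the least-busy pick happens):

from litellm import Router

# placeholder deployments for one "azure-model" group; keys and endpoints are not real
model_list = [
    {
        "model_name": "azure-model",
        "litellm_params": {
            "model": "azure/gpt-turbo",
            "api_key": "os.environ/AZURE_FRANCE_API_KEY",
            "api_base": "https://openai-france-1234.openai.azure.com",
        },
    },
    {
        "model_name": "azure-model",
        "litellm_params": {
            "model": "azure/gpt-35-turbo",
            "api_key": "os.environ/AZURE_EUROPE_API_KEY",
            "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
        },
    },
]

# routing_strategy="least-busy" attaches a LeastBusyLoggingHandler to
# litellm.input_callback / litellm.callbacks, per the hunk above
router = Router(model_list=model_list, routing_strategy="least-busy")

# the router then consults the handler's in-flight counts when picking a deployment
deployment = router.get_available_deployment(model="azure-model")
print(deployment["litellm_params"]["model"])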
@@ -664,6 +671,7 @@ class Router:
            return kwargs
        except Exception as e:
            raise e

    def _set_cooldown_deployments(self,
                                  deployment: str):
        """
@@ -873,6 +881,10 @@ class Router:
        for model in self.model_list:
            litellm_params = model.get("litellm_params", {})
            model_name = litellm_params.get("model")
            #### MODEL ID INIT ########
            model_info = model.get("model_info", {})
            model_info["id"] = model_info.get("id", str(uuid.uuid4()))
            model["model_info"] = model_info
            #### for OpenAI / Azure we need to initialize the Client for High Traffic ########
            custom_llm_provider = litellm_params.get("custom_llm_provider")
            if custom_llm_provider is None:
@@ -1119,8 +1131,8 @@ class Router:
        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
        if len(healthy_deployments) == 0:
            # check if the user sent in a deployment name instead

            healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]

        self.print_verbose(f"initial list of deployments: {healthy_deployments}")
        deployments_to_remove = []
        cooldown_deployments = self._get_cooldown_deployments()
@@ -1140,13 +1152,24 @@ class Router:
            model = litellm.model_alias_map[
                model
            ] # update the model to the actual value if an alias has been passed in
        if self.routing_strategy == "least-busy":
            if len(self.healthy_deployments) > 0:
                for item in self.healthy_deployments:
                    if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
                        return item[0]
        if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
            deployments = self.leastbusy_logger.get_available_deployments(model_group=model)
            # pick least busy deployment
            min_traffic = float('inf')
            min_deployment = None
            for k, v in deployments.items():
                if v < min_traffic:
                    min_deployment = k
            ############## No Available Deployments passed, we do a random pick #################
            if min_deployment is None:
                min_deployment = random.choice(healthy_deployments)
            ############## Available Deployments passed, we find the relevant item #################
            else:
                raise ValueError("No models available.")
                for m in healthy_deployments:
                    if m["model_info"]["id"] == min_deployment:
                        return m
                min_deployment = random.choice(healthy_deployments)
            return min_deployment
        elif self.routing_strategy == "simple-shuffle":
            # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
            ############## Check if we can do a RPM/TPM based weighted pick #################
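The new selection logic above reduces to "take the deployment id with the lowest in-flight count, else fall back to a random healthy deployment". A standalone sketch of that selection (the function name is illustrative; note that tracking the running minimum requires updating min_traffic inside the loop, which the hunk as rendered does not do):

import random

def pick_least_busy(deployments: dict, healthy_deployments: list):
    # deployments: {deployment_id: in_flight_request_count}, as returned by
    # LeastBusyLoggingHandler.get_available_deployments()
    min_traffic = float("inf")
    min_deployment = None
    for deployment_id, traffic in deployments.items():
        if traffic < min_traffic:
            min_traffic = traffic          # keep the running minimum
            min_deployment = deployment_id
    if min_deployment is None:
        # no traffic data yet -> random pick among healthy deployments
        return random.choice(healthy_deployments)
    for m in healthy_deployments:
        if m["model_info"]["id"] == min_deployment:
            return m
    return random.choice(healthy_deployments)

healthy = [{"model_name": "azure-model", "model_info": {"id": "a"}},
           {"model_name": "azure-model", "model_info": {"id": "b"}}]
print(pick_least_busy({"a": 3, "b": 1}, healthy))  # picks the deployment with id "b"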
litellm/router_strategy/least_busy.py (new file, 96 lines)
@@ -0,0 +1,96 @@
#### What this does ####
# identifies least busy deployment
# How is this achieved?
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic

import dotenv, os, requests
from typing import Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger

class LeastBusyLoggingHandler(CustomLogger):

    def __init__(self, router_cache: DualCache):
        self.router_cache = router_cache
        self.mapping_deployment_to_id: dict = {}


    def log_pre_api_call(self, model, messages, kwargs):
        """
        Log when a model is being used.

        Caching based on model group.
        """
        try:

            if kwargs['litellm_params'].get('metadata') is None:
                pass
            else:
                deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
                model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
                id = kwargs['litellm_params'].get('model_info', {}).get('id', None)
                if deployment is None or model_group is None or id is None:
                    return

                # map deployment to id
                self.mapping_deployment_to_id[deployment] = id

                request_count_api_key = f"{model_group}_request_count"
                # update cache
                request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
                request_count_dict[deployment] = request_count_dict.get(deployment, 0) + 1
                self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
        except Exception as e:
            pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            if kwargs['litellm_params'].get('metadata') is None:
                pass
            else:
                deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
                model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
                if deployment is None or model_group is None:
                    return


                request_count_api_key = f"{model_group}_request_count"
                # decrement count in cache
                request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
                request_count_dict[deployment] = request_count_dict.get(deployment)
                self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
        except Exception as e:
            pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            if kwargs['litellm_params'].get('metadata') is None:
                pass
            else:
                deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
                model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
                if deployment is None or model_group is None:
                    return


                request_count_api_key = f"{model_group}_request_count"
                # decrement count in cache
                request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
                request_count_dict[deployment] = request_count_dict.get(deployment)
                self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
        except Exception as e:
            pass

    def get_available_deployments(self, model_group: str):
        request_count_api_key = f"{model_group}_request_count"
        request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
        # map deployment to id
        return_dict = {}
        for key, value in request_count_dict.items():
            return_dict[self.mapping_deployment_to_id[key]] = value
        return return_dict
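The handler's bookkeeping is a per-model-group dict of in-flight counts keyed by deployment. A minimal in-memory sketch of the intended increment/decrement cycle follows; note that the success/failure handlers above assign request_count_dict.get(deployment) back unchanged where the comment says "decrement count in cache", so the decrement below is the presumed intent, not what this commit ships, and the function names are illustrative:

request_counts: dict = {}  # stands in for the DualCache used above

def on_pre_call(model_group: str, deployment: str):
    key = f"{model_group}_request_count"
    counts = request_counts.get(key) or {}
    counts[deployment] = counts.get(deployment, 0) + 1  # request is now in flight
    request_counts[key] = counts

def on_done(model_group: str, deployment: str):
    key = f"{model_group}_request_count"
    counts = request_counts.get(key) or {}
    # presumed intent of the "decrement count in cache" comment
    counts[deployment] = max(counts.get(deployment, 1) - 1, 0)
    request_counts[key] = counts

on_pre_call("azure-model", "azure/gpt-turbo")
on_pre_call("azure-model", "azure/gpt-35-turbo")
on_done("azure-model", "azure/gpt-turbo")
print(request_counts)  # {'azure-model_request_count': {'azure/gpt-turbo': 0, 'azure/gpt-35-turbo': 1}}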
litellm/tests/test_least_busy_routing.py (new file, 79 lines)
@@ -0,0 +1,79 @@
# #### What this tests ####
# # This tests the router's ability to identify the least busy deployment

# #
# # How is this achieved?
# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# # - use litellm.success + failure callbacks to log when a request completed
# # - in get_available_deployment, for a given model group name -> pick based on traffic

# import sys, os, asyncio, time
# import traceback
# from dotenv import load_dotenv

# load_dotenv()
# import os

# sys.path.insert(
#     0, os.path.abspath("../..")
# )  # Adds the parent directory to the system path
# import pytest
# from litellm import Router
# import litellm

# async def test_least_busy_routing():
#     model_list = [{
#         "model_name": "azure-model",
#         "litellm_params": {
#             "model": "azure/gpt-turbo",
#             "api_key": "os.environ/AZURE_FRANCE_API_KEY",
#             "api_base": "https://openai-france-1234.openai.azure.com",
#             "rpm": 1440,
#         }
#     }, {
#         "model_name": "azure-model",
#         "litellm_params": {
#             "model": "azure/gpt-35-turbo",
#             "api_key": "os.environ/AZURE_EUROPE_API_KEY",
#             "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
#             "rpm": 6
#         }
#     }, {
#         "model_name": "azure-model",
#         "litellm_params": {
#             "model": "azure/gpt-35-turbo",
#             "api_key": "os.environ/AZURE_CANADA_API_KEY",
#             "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
#             "rpm": 6
#         }
#     }]
#     router = Router(model_list=model_list,
#                     routing_strategy="least-busy",
#                     set_verbose=False,
#                     num_retries=3)  # type: ignore

#     async def call_azure_completion():
#         try:
#             response = await router.acompletion(
#                 model="azure-model",
#                 messages=[
#                     {
#                         "role": "user",
#                         "content": "hello this request will pass"
#                     }
#                 ]
#             )
#             print("\n response", response)
#             return response
#         except:
#             return None

#     n = 1000
#     start_time = time.time()
#     tasks = [call_azure_completion() for _ in range(n)]
#     chat_completions = await asyncio.gather(*tasks)
#     successful_completions = [c for c in chat_completions if c is not None]
#     print(n, time.time() - start_time, len(successful_completions))

# asyncio.run(test_least_busy_routing())
@@ -226,10 +226,12 @@ def test_call_one_endpoint():
            )

            print("\n response", response)
    tasks = []
    tasks.append(call_azure_completion())
    tasks.append(call_bedrock_claude())
    tasks.append(call_azure_embedding())

    asyncio.run(call_azure_completion())
    asyncio.run(call_bedrock_claude())
    asyncio.run(call_azure_embedding())
    asyncio.gather(**tasks)

    os.environ["AZURE_API_BASE"] = old_api_base
    os.environ["AZURE_API_KEY"] = old_api_key
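One note on the gather call above: asyncio.gather takes awaitables positionally, so a standalone version of this fan-out would unpack the list with *tasks rather than **tasks (the latter raises a TypeError on a list). A runnable sketch with stand-in coroutines:

import asyncio

async def call_one(name: str) -> str:
    await asyncio.sleep(0)  # stand-in for an API call
    return f"{name} done"

async def main():
    tasks = [call_one("azure"), call_one("bedrock"), call_one("embedding")]
    # gather expects positional awaitables, i.e. *tasks
    results = await asyncio.gather(*tasks)
    print(results)

asyncio.run(main())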
@@ -562,19 +562,25 @@ class Logging:
            **self.optional_params
        }

    def _pre_call(self, input, api_key, model=None, additional_args={}):
        """
        Common helper function across the sync + async pre-call function
        """
        # print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
        self.model_call_details["input"] = input
        self.model_call_details["api_key"] = api_key
        self.model_call_details["additional_args"] = additional_args
        self.model_call_details["log_event_type"] = "pre_api_call"
        if (
            model
        ): # if model name was changed pre-call, overwrite the initial model call name with the new one
            self.model_call_details["model"] = model

    def pre_call(self, input, api_key, model=None, additional_args={}):
        # Log the exact input to the LLM API
        litellm.error_logs['PRE_CALL'] = locals()
        try:
            # print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
            self.model_call_details["input"] = input
            self.model_call_details["api_key"] = api_key
            self.model_call_details["additional_args"] = additional_args
            self.model_call_details["log_event_type"] = "pre_api_call"
            if (
                model
            ): # if model name was changed pre-call, overwrite the initial model call name with the new one
                self.model_call_details["model"] = model
            self._pre_call(input=input, api_key=api_key, model=model, additional_args=additional_args)

            # User Logging -> if you pass in a custom logging function
            headers = additional_args.get("headers", {})
@@ -688,6 +694,34 @@ class Logging:
            if capture_exception: # log this error to sentry for debugging
                capture_exception(e)

    async def async_pre_call(self, result=None, start_time=None, end_time=None, **kwargs):
        """
        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
        """
        start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
        print_verbose(f"Async input callbacks: {litellm._async_input_callback}")
        for callback in litellm._async_input_callback:
            try:
                if isinstance(callback, CustomLogger): # custom logger class
                    print_verbose(f"Async input callbacks: CustomLogger")
                    asyncio.create_task(callback.async_log_input_event(
                        model=self.model,
                        messages=self.messages,
                        kwargs=self.model_call_details,
                    ))
                if callable(callback): # custom logger functions
                    print_verbose(f"Async success callbacks: async_log_event")
                    asyncio.create_task(customLogger.async_log_input_event(
                        model=self.model,
                        messages=self.messages,
                        kwargs=self.model_call_details,
                        print_verbose=print_verbose,
                        callback_func=callback
                    ))
            except:
                print_verbose(
                    f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
                )
    def post_call(self, original_response, input=None, api_key=None, additional_args={}):
        # Log the exact result from the LLM API, for streaming - log the type of response received
        litellm.error_logs['POST_CALL'] = locals()
@@ -1289,6 +1323,17 @@ def client(original_function):
                function_id=function_id
            )
        ## ASYNC CALLBACKS
        if len(litellm.input_callback) > 0:
            removed_async_items = []
            for index, callback in enumerate(litellm.input_callback):
                if inspect.iscoroutinefunction(callback):
                    litellm._async_input_callback.append(callback)
                    removed_async_items.append(index)

            # Pop the async items from input_callback in reverse order to avoid index issues
            for index in reversed(removed_async_items):
                litellm.input_callback.pop(index)

        if len(litellm.success_callback) > 0:
            removed_async_items = []
            for index, callback in enumerate(litellm.success_callback):
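The async-callback routing above is a partition step: coroutine functions are moved from litellm.input_callback into litellm._async_input_callback. A standalone sketch of the same pattern on plain lists (the function and callback names are illustrative only):

import inspect

def split_async_callbacks(callbacks: list):
    # returns (sync_callbacks, async_callbacks), mirroring how client()
    # moves coroutine functions into the litellm._async_* lists
    sync_cbs, async_cbs = [], []
    for cb in callbacks:
        if inspect.iscoroutinefunction(cb):
            async_cbs.append(cb)
        else:
            sync_cbs.append(cb)
    return sync_cbs, async_cbs

async def my_async_logger(kwargs):   # hypothetical user callback
    ...

def my_sync_logger(kwargs):          # hypothetical user callback
    ...

print(split_async_callbacks([my_async_logger, my_sync_logger]))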
@@ -1307,7 +1352,7 @@ def client(original_function):
                    litellm._async_failure_callback.append(callback)
                    removed_async_items.append(index)

            # Pop the async items from success_callback in reverse order to avoid index issues
            # Pop the async items from failure_callback in reverse order to avoid index issues
            for index in reversed(removed_async_items):
                litellm.failure_callback.pop(index)
        if add_breadcrumb: