fix(router.py): fix least-busy routing

Krrish Dholakia committed 2023-12-08 20:29:37 -08:00
commit a65c8919fc (parent d8e60d7290)
8 changed files with 292 additions and 31 deletions

View file

@ -9,6 +9,7 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []

View file

@ -61,8 +61,6 @@ class InMemoryCache(BaseCache):
cached_response = json.loads(original_cached_response)
except:
cached_response = original_cached_response
if isinstance(cached_response, dict):
cached_response['cache'] = True # set cache-hit flag to True
return cached_response
return None
@ -110,8 +108,6 @@ class RedisCache(BaseCache):
cached_response = json.loads(cached_response) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response)
if isinstance(cached_response, dict):
cached_response['cache'] = True # set cache-hit flag to True
return cached_response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception

View file

@ -27,6 +27,9 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
async def async_log_pre_api_call(self, model, messages, kwargs):
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
pass
@ -51,6 +54,22 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
try:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["log_event_type"] = "pre_api_call"
await callback_func(
kwargs,
)
print_verbose(
f"Custom Logger - model call details: {kwargs}"
)
except:
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
# Method definition
try:

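For illustration, a minimal sketch of a user-defined logger plugging into the async hooks added above; the class name and print statements are hypothetical, and only the CustomLogger base class and the hook signatures shown in this file are assumed.

from litellm.integrations.custom_logger import CustomLogger

class MyAsyncLogger(CustomLogger):  # hypothetical subclass, for illustration only
    async def async_log_pre_api_call(self, model, messages, kwargs):
        # fires just before the request is sent to the deployment
        print(f"pre-call to {model} with {len(messages)} messages")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # fires once the request completes successfully
        print("success:", kwargs.get("model"))

An instance of such a logger would then be appended to litellm.callbacks, which is how the router registers its LeastBusyLoggingHandler elsewhere in this commit.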
View file

@ -9,14 +9,14 @@
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback
import random, threading, time, traceback, uuid
import litellm, openai
from litellm.caching import RedisCache, InMemoryCache, DualCache
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
class Router:
"""
Example usage:
@ -57,6 +57,7 @@ class Router:
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
num_retries: int = 0
tenacity = None
leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
def __init__(self,
model_list: Optional[list] = None,
@ -98,7 +99,7 @@ class Router:
self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model
self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model
self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call)
# make Router.chat.completions.create compatible for openai.chat.completions.create
self.chat = litellm.Chat(params=default_litellm_params)
@ -107,10 +108,6 @@ class Router:
self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0)
### HEALTH CHECK THREAD ###
if self.routing_strategy == "least-busy":
self._start_health_check_thread()
### CACHING ###
cache_type = "local" # default to an in-memory cache
redis_cache = None
@ -137,6 +134,16 @@ class Router:
litellm.cache = litellm.Cache(type=cache_type, **cache_config)
self.cache_responses = cache_responses
self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
### ROUTING SETUP ###
if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
## add callback
if isinstance(litellm.input_callback, list):
litellm.input_callback.append(self.leastbusy_logger) # type: ignore
else:
litellm.input_callback = [self.leastbusy_logger] # type: ignore
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.leastbusy_logger) # type: ignore
## USAGE TRACKING ##
if isinstance(litellm.success_callback, list):
litellm.success_callback.append(self.deployment_callback)
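For context, a condensed sketch (not part of the diff) of how this routing setup would be exercised from user code; the deployment entries and endpoints are placeholders modeled on the commented-out test added later in this commit:

from litellm import Router

model_list = [
    {   # two deployments behind the same model group; keys and bases are placeholders
        "model_name": "azure-model",
        "litellm_params": {"model": "azure/gpt-turbo", "api_key": "...", "api_base": "https://example-1.openai.azure.com"},
    },
    {
        "model_name": "azure-model",
        "litellm_params": {"model": "azure/gpt-35-turbo", "api_key": "...", "api_base": "https://example-2.openai.azure.com"},
    },
]

# routing_strategy="least-busy" triggers the ROUTING SETUP block above: a
# LeastBusyLoggingHandler is created and appended to litellm.input_callback and
# litellm.callbacks, so in-flight request counts are tracked per deployment.
router = Router(model_list=model_list, routing_strategy="least-busy")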
@ -664,6 +671,7 @@ class Router:
return kwargs
except Exception as e:
raise e
def _set_cooldown_deployments(self,
deployment: str):
"""
@ -873,6 +881,10 @@ class Router:
for model in self.model_list:
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
#### MODEL ID INIT ########
model_info = model.get("model_info", {})
model_info["id"] = model_info.get("id", str(uuid.uuid4()))
model["model_info"] = model_info
#### for OpenAI / Azure we need to initialize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
if custom_llm_provider is None:
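The MODEL ID INIT block above assigns every deployment a stable model_info["id"] (a fresh uuid.uuid4() when the user supplies none); the least-busy strategy keys its traffic counts on this id. A small illustration with invented values:

import uuid

model = {"model_name": "azure-model", "litellm_params": {"model": "azure/gpt-turbo"}}

model_info = model.get("model_info", {})
model_info["id"] = model_info.get("id", str(uuid.uuid4()))
model["model_info"] = model_info
# model["model_info"] now holds a random id, e.g. {"id": "2f0c..."}; the least-busy
# logger reports traffic keyed by these ids, and get_available_deployment matches
# them back to a deployment via m["model_info"]["id"].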
@ -1119,8 +1131,8 @@ class Router:
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead
healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
self.print_verbose(f"initial list of deployments: {healthy_deployments}")
deployments_to_remove = []
cooldown_deployments = self._get_cooldown_deployments()
@ -1140,13 +1152,24 @@ class Router:
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
if self.routing_strategy == "least-busy":
    if len(self.healthy_deployments) > 0:
        for item in self.healthy_deployments:
            if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
                return item[0]
    else:
        raise ValueError("No models available.")
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
    deployments = self.leastbusy_logger.get_available_deployments(model_group=model)
    # pick the least busy deployment
    min_traffic = float('inf')
    min_deployment = None
    for k, v in deployments.items():
        if v < min_traffic:
            min_traffic = v
            min_deployment = k
    ############## No Available Deployments passed, we do a random pick #################
    if min_deployment is None:
        min_deployment = random.choice(healthy_deployments)
    ############## Available Deployments passed, we find the relevant item #################
    else:
        for m in healthy_deployments:
            if m["model_info"]["id"] == min_deployment:
                return m
        min_deployment = random.choice(healthy_deployments)
    return min_deployment
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if we can do a RPM/TPM based weighted pick #################

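To make the least-busy selection above concrete, a worked example assuming three tracked deployments in one model group (the ids and in-flight counts are invented):

# hypothetical snapshot from leastbusy_logger.get_available_deployments(model_group="azure-model")
deployments = {"id-1": 3, "id-2": 0, "id-3": 7}

min_traffic = float("inf")
min_deployment = None
for k, v in deployments.items():
    if v < min_traffic:
        min_traffic = v
        min_deployment = k

print(min_deployment)  # prints id-2 -> the idle deployment wins
# The router then returns the entry in healthy_deployments whose model_info["id"]
# equals "id-2", falling back to random.choice() when nothing is tracked yet.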
View file

@ -0,0 +1,96 @@
#### What this does ####
# identifies least busy deployment
# How is this achieved?
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic
import dotenv, os, requests
from typing import Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
class LeastBusyLoggingHandler(CustomLogger):
def __init__(self, router_cache: DualCache):
self.router_cache = router_cache
self.mapping_deployment_to_id: dict = {}
def log_pre_api_call(self, model, messages, kwargs):
"""
Log when a model is being used.
Caching based on model group.
"""
try:
if kwargs['litellm_params'].get('metadata') is None:
pass
else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
id = kwargs['litellm_params'].get('model_info', {}).get('id', None)
if deployment is None or model_group is None or id is None:
return
# map deployment to id
self.mapping_deployment_to_id[deployment] = id
request_count_api_key = f"{model_group}_request_count"
# update cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict[deployment] = request_count_dict.get(deployment, 0) + 1
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
except Exception as e:
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs['litellm_params'].get('metadata') is None:
pass
else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
if deployment is None or model_group is None:
return
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict[deployment] = request_count_dict.get(deployment, 1) - 1
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
except Exception as e:
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs['litellm_params'].get('metadata') is None:
pass
else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
if deployment is None or model_group is None:
return
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict[deployment] = request_count_dict.get(deployment, 1) - 1
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
except Exception as e:
pass
def get_available_deployments(self, model_group: str):
request_count_api_key = f"{model_group}_request_count"
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
# map deployment to id
return_dict = {}
for key, value in request_count_dict.items():
return_dict[self.mapping_deployment_to_id[key]] = value
return return_dict
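A self-contained sketch (not part of the diff) of the handler's bookkeeping, using only the classes imported in this file; the deployment name, model group, and kwargs payload are invented for illustration:

from litellm.caching import DualCache, InMemoryCache
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler

cache = DualCache(redis_cache=None, in_memory_cache=InMemoryCache())
handler = LeastBusyLoggingHandler(router_cache=cache)

# the router passes deployment / model_group / id through litellm_params
kwargs = {
    "litellm_params": {
        "metadata": {"deployment": "azure/gpt-turbo", "model_group": "azure-model"},
        "model_info": {"id": "id-1"},
    }
}

handler.log_pre_api_call(model="azure/gpt-turbo", messages=[], kwargs=kwargs)
print(handler.get_available_deployments(model_group="azure-model"))
# expected output: {'id-1': 1} -> one request currently in flight for that deployment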

View file

@ -0,0 +1,79 @@
# #### What this tests ####
# # This tests the router's ability to identify the least busy deployment
# #
# # How is this achieved?
# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# # - use litellm.success + failure callbacks to log when a request completed
# # - in get_available_deployment, for a given model group name -> pick based on traffic
# import sys, os, asyncio, time
# import traceback
# from dotenv import load_dotenv
# load_dotenv()
# import os
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import pytest
# from litellm import Router
# import litellm
# async def test_least_busy_routing():
# model_list = [{
# "model_name": "azure-model",
# "litellm_params": {
# "model": "azure/gpt-turbo",
# "api_key": "os.environ/AZURE_FRANCE_API_KEY",
# "api_base": "https://openai-france-1234.openai.azure.com",
# "rpm": 1440,
# }
# }, {
# "model_name": "azure-model",
# "litellm_params": {
# "model": "azure/gpt-35-turbo",
# "api_key": "os.environ/AZURE_EUROPE_API_KEY",
# "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
# "rpm": 6
# }
# }, {
# "model_name": "azure-model",
# "litellm_params": {
# "model": "azure/gpt-35-turbo",
# "api_key": "os.environ/AZURE_CANADA_API_KEY",
# "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
# "rpm": 6
# }
# }]
# router = Router(model_list=model_list,
# routing_strategy="least-busy",
# set_verbose=False,
# num_retries=3) # type: ignore
# async def call_azure_completion():
# try:
# response = await router.acompletion(
# model="azure-model",
# messages=[
# {
# "role": "user",
# "content": "hello this request will pass"
# }
# ]
# )
# print("\n response", response)
# return response
# except:
# return None
# n = 1000
# start_time = time.time()
# tasks = [call_azure_completion() for _ in range(n)]
# chat_completions = await asyncio.gather(*tasks)
# successful_completions = [c for c in chat_completions if c is not None]
# print(n, time.time() - start_time, len(successful_completions))
# asyncio.run(test_least_busy_routing())

View file

@ -226,10 +226,12 @@ def test_call_one_endpoint():
)
print("\n response", response)
tasks = []
tasks.append(call_azure_completion())
tasks.append(call_bedrock_claude())
tasks.append(call_azure_embedding())
asyncio.gather(**tasks)
asyncio.run(call_azure_completion())
asyncio.run(call_bedrock_claude())
asyncio.run(call_azure_embedding())
os.environ["AZURE_API_BASE"] = old_api_base
os.environ["AZURE_API_KEY"] = old_api_key

View file

@ -562,19 +562,25 @@ class Logging:
**self.optional_params
}
def _pre_call(self, input, api_key, model=None, additional_args={}):
"""
Common helper function across the sync + async pre-call function
"""
# print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
self.model_call_details["input"] = input
self.model_call_details["api_key"] = api_key
self.model_call_details["additional_args"] = additional_args
self.model_call_details["log_event_type"] = "pre_api_call"
if (
model
): # if model name was changed pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
def pre_call(self, input, api_key, model=None, additional_args={}):
# Log the exact input to the LLM API
litellm.error_logs['PRE_CALL'] = locals()
try:
# print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
self.model_call_details["input"] = input
self.model_call_details["api_key"] = api_key
self.model_call_details["additional_args"] = additional_args
self.model_call_details["log_event_type"] = "pre_api_call"
if (
model
): # if model name was changed pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
self._pre_call(input=input, api_key=api_key, model=model, additional_args=additional_args)
# User Logging -> if you pass in a custom logging function
headers = additional_args.get("headers", {})
@ -688,6 +694,34 @@ class Logging:
if capture_exception: # log this error to sentry for debugging
capture_exception(e)
async def async_pre_call(self, result=None, start_time=None, end_time=None, **kwargs):
"""
Implementing async callbacks to handle asyncio event loop issues when custom integrations need to use async functions.
"""
start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
print_verbose(f"Async input callbacks: {litellm._async_input_callback}")
for callback in litellm._async_input_callback:
try:
if isinstance(callback, CustomLogger): # custom logger class
print_verbose(f"Async input callbacks: CustomLogger")
asyncio.create_task(callback.async_log_input_event(
model=self.model,
messages=self.messages,
kwargs=self.model_call_details,
))
if callable(callback): # custom logger functions
print_verbose(f"Async input callbacks: async_log_event")
asyncio.create_task(customLogger.async_log_input_event(
model=self.model,
messages=self.messages,
kwargs=self.model_call_details,
print_verbose=print_verbose,
callback_func=callback
))
except:
print_verbose(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while input logging {traceback.format_exc()}"
)
def post_call(self, original_response, input=None, api_key=None, additional_args={}):
# Log the exact result from the LLM API, for streaming - log the type of response received
litellm.error_logs['POST_CALL'] = locals()
@ -1289,6 +1323,17 @@ def client(original_function):
function_id=function_id
)
## ASYNC CALLBACKS
if len(litellm.input_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.input_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_input_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from input_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.input_callback.pop(index)
if len(litellm.success_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.success_callback):
@ -1307,7 +1352,7 @@ def client(original_function):
litellm._async_failure_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
# Pop the async items from failure_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.failure_callback.pop(index)
if add_breadcrumb:
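The ASYNC CALLBACKS block above applies the same splitting pattern to the input, success, and failure callback lists; a standalone sketch of the idea with made-up callback functions:

import inspect

def sync_cb(kwargs):
    print("sync input callback")

async def async_cb(kwargs):
    print("async input callback")

input_callback = [sync_cb, async_cb]      # stand-in for litellm.input_callback
_async_input_callback = []                # stand-in for litellm._async_input_callback

removed_async_items = []
for index, callback in enumerate(input_callback):
    if inspect.iscoroutinefunction(callback):
        # coroutine functions are routed to the async list so they can be awaited
        # (scheduled via asyncio.create_task in async_pre_call) instead of called directly
        _async_input_callback.append(callback)
        removed_async_items.append(index)

# pop in reverse order so earlier indices stay valid
for index in reversed(removed_async_items):
    input_callback.pop(index)

print(input_callback)         # only sync_cb remains
print(_async_input_callback)  # async_cb was moved here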