diff --git a/litellm/__init__.py b/litellm/__init__.py
index 11757eab8..b7aeeb210 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -9,6 +9,7 @@ input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 callbacks: List[Callable] = []
+_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
 _async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
 _async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
diff --git a/litellm/caching.py b/litellm/caching.py
index c79f667df..bbad49716 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -61,8 +61,6 @@ class InMemoryCache(BaseCache):
                 cached_response = json.loads(original_cached_response)
             except:
                 cached_response = original_cached_response
-            if isinstance(cached_response, dict):
-                cached_response['cache'] = True # set cache-hit flag to True
             return cached_response
         return None

@@ -110,8 +108,6 @@ class RedisCache(BaseCache):
                 cached_response = json.loads(cached_response) # Convert string to dictionary
             except:
                 cached_response = ast.literal_eval(cached_response)
-            if isinstance(cached_response, dict):
-                cached_response['cache'] = True # set cache-hit flag to True
             return cached_response
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index 586c90819..d0efc2fb0 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -27,6 +27,9 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback

     def log_failure_event(self, kwargs, response_obj, start_time, end_time):
         pass
+
+    async def async_log_pre_api_call(self, model, messages, kwargs):
+        pass

     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         pass
@@ -51,6 +54,22 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
             traceback.print_exc()
             print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

+    async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
+        try:
+            kwargs["model"] = model
+            kwargs["messages"] = messages
+            kwargs["log_event_type"] = "pre_api_call"
+            await callback_func(
+                kwargs,
+            )
+            print_verbose(
+                f"Custom Logger - model call details: {kwargs}"
+            )
+        except:
+            traceback.print_exc()
+            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+
+
     def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
         # Method definition
         try:
diff --git a/litellm/router.py b/litellm/router.py
index edbc3cd74..5a0cf8a37 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -9,14 +9,14 @@
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal, Any
-import random, threading, time, traceback
+import random, threading, time, traceback, uuid
 import litellm, openai
 from litellm.caching import RedisCache, InMemoryCache, DualCache
 import logging, asyncio
 import inspect, concurrent
 from openai import AsyncOpenAI
 from collections import defaultdict
-
+from litellm.router_strategy.least_busy import LeastBusyLoggingHandler

 class Router:
     """
     Example usage:
@@ -57,6 +57,7 @@ class Router:
     default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
     num_retries: int = 0
     tenacity = None
+    leastbusy_logger: Optional[LeastBusyLoggingHandler] = None

     def __init__(self,
                  model_list: Optional[list] = None,
@@ -98,7 +99,7 @@ class Router:
         self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model
         self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model
         self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call)
-        
+
         # make Router.chat.completions.create compatible for openai.chat.completions.create
         self.chat = litellm.Chat(params=default_litellm_params)

@@ -107,10 +108,6 @@ class Router:
         self.default_litellm_params.setdefault("timeout", timeout)
         self.default_litellm_params.setdefault("max_retries", 0)
-
-        ### HEALTH CHECK THREAD ###
-        if self.routing_strategy == "least-busy":
-            self._start_health_check_thread()

         ### CACHING ###
         cache_type = "local" # default to an in-memory cache
         redis_cache = None
@@ -137,6 +134,16 @@ class Router:
             litellm.cache = litellm.Cache(type=cache_type, **cache_config)
             self.cache_responses = cache_responses
         self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
+        ### ROUTING SETUP ###
+        if routing_strategy == "least-busy":
+            self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
+            ## add callback
+            if isinstance(litellm.input_callback, list):
+                litellm.input_callback.append(self.leastbusy_logger) # type: ignore
+            else:
+                litellm.input_callback = [self.leastbusy_logger] # type: ignore
+            if isinstance(litellm.callbacks, list):
+                litellm.callbacks.append(self.leastbusy_logger) # type: ignore
         ## USAGE TRACKING ##
         if isinstance(litellm.success_callback, list):
             litellm.success_callback.append(self.deployment_callback)
@@ -664,6 +671,7 @@ class Router:
             return kwargs
         except Exception as e:
             raise e
+

     def _set_cooldown_deployments(self, deployment: str):
         """
@@ -873,6 +881,10 @@ class Router:
         for model in self.model_list:
             litellm_params = model.get("litellm_params", {})
             model_name = litellm_params.get("model")
+            #### MODEL ID INIT ########
+            model_info = model.get("model_info", {})
+            model_info["id"] = model_info.get("id", str(uuid.uuid4()))
+            model["model_info"] = model_info
             #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
             custom_llm_provider = litellm_params.get("custom_llm_provider")
             if custom_llm_provider is None:
@@ -1119,8 +1131,8 @@ class Router:
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
         if len(healthy_deployments) == 0:
             # check if the user sent in a deployment name instead
-
             healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
+
         self.print_verbose(f"initial list of deployments: {healthy_deployments}")
         deployments_to_remove = []
         cooldown_deployments = self._get_cooldown_deployments()
@@ -1140,13 +1152,24 @@ class Router:
             model = litellm.model_alias_map[
                 model
             ] # update the model to the actual value if an alias has been passed in
-        if self.routing_strategy == "least-busy":
-            if len(self.healthy_deployments) > 0:
-                for item in self.healthy_deployments:
-                    if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
-                        return item[0]
+        if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
+            deployments = self.leastbusy_logger.get_available_deployments(model_group=model)
+            # pick least busy deployment
+            min_traffic = float('inf')
+            min_deployment = None
+            for k, v in deployments.items():
+                if v < min_traffic:
+                    min_deployment, min_traffic = k, v
+            ############## No Available Deployments passed, we do a random pick #################
+            if min_deployment is None:
+                min_deployment = random.choice(healthy_deployments)
+            ############## Available Deployments passed, we find the relevant item #################
             else:
-                raise ValueError("No models available.")
+                for m in healthy_deployments:
+                    if m["model_info"]["id"] == min_deployment:
+                        return m
+                min_deployment = random.choice(healthy_deployments)
+            return min_deployment
         elif self.routing_strategy == "simple-shuffle":
             # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
             ############## Check if we can do a RPM/TPM based weighted pick #################
diff --git a/litellm/router_strategy/least_busy.py b/litellm/router_strategy/least_busy.py
new file mode 100644
index 000000000..0080e3fa8
--- /dev/null
+++ b/litellm/router_strategy/least_busy.py
@@ -0,0 +1,96 @@
+#### What this does ####
+# identifies least busy deployment
+# How is this achieved?
+# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
+# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
+# - use litellm.success + failure callbacks to log when a request completed
+# - in get_available_deployment, for a given model group name -> pick based on traffic
+
+import dotenv, os, requests
+from typing import Optional
+dotenv.load_dotenv() # Loading env variables using dotenv
+import traceback
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+
+class LeastBusyLoggingHandler(CustomLogger):
+
+    def __init__(self, router_cache: DualCache):
+        self.router_cache = router_cache
+        self.mapping_deployment_to_id: dict = {}
+
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        """
+        Log when a model is being used.
+
+        Caching based on model group.
+ """ + try: + + if kwargs['litellm_params'].get('metadata') is None: + pass + else: + deployment = kwargs['litellm_params']['metadata'].get('deployment', None) + model_group = kwargs['litellm_params']['metadata'].get('model_group', None) + id = kwargs['litellm_params'].get('model_info', {}).get('id', None) + if deployment is None or model_group is None or id is None: + return + + # map deployment to id + self.mapping_deployment_to_id[deployment] = id + + request_count_api_key = f"{model_group}_request_count" + # update cache + request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} + request_count_dict[deployment] = request_count_dict.get(deployment, 0) + 1 + self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict) + except Exception as e: + pass + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + if kwargs['litellm_params'].get('metadata') is None: + pass + else: + deployment = kwargs['litellm_params']['metadata'].get('deployment', None) + model_group = kwargs['litellm_params']['metadata'].get('model_group', None) + if deployment is None or model_group is None: + return + + + request_count_api_key = f"{model_group}_request_count" + # decrement count in cache + request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} + request_count_dict[deployment] = request_count_dict.get(deployment) + self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict) + except Exception as e: + pass + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + try: + if kwargs['litellm_params'].get('metadata') is None: + pass + else: + deployment = kwargs['litellm_params']['metadata'].get('deployment', None) + model_group = kwargs['litellm_params']['metadata'].get('model_group', None) + if deployment is None or model_group is None: + return + + + request_count_api_key = f"{model_group}_request_count" + # decrement count in cache + request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} + request_count_dict[deployment] = request_count_dict.get(deployment) + self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict) + except Exception as e: + pass + + def get_available_deployments(self, model_group: str): + request_count_api_key = f"{model_group}_request_count" + request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} + # map deployment to id + return_dict = {} + for key, value in request_count_dict.items(): + return_dict[self.mapping_deployment_to_id[key]] = value + return return_dict \ No newline at end of file diff --git a/litellm/tests/test_least_busy_routing.py b/litellm/tests/test_least_busy_routing.py new file mode 100644 index 000000000..05d3f3ec6 --- /dev/null +++ b/litellm/tests/test_least_busy_routing.py @@ -0,0 +1,79 @@ +# #### What this tests #### +# # This tests the router's ability to identify the least busy deployment + +# # +# # How is this achieved? 
+# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
+# # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
+# # - use litellm.success + failure callbacks to log when a request completed
+# # - in get_available_deployment, for a given model group name -> pick based on traffic
+
+# import sys, os, asyncio, time
+# import traceback
+# from dotenv import load_dotenv
+
+# load_dotenv()
+# import os
+
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# ) # Adds the parent directory to the system path
+# import pytest
+# from litellm import Router
+# import litellm
+
+# async def test_least_busy_routing():
+#     model_list = [{
+#         "model_name": "azure-model",
+#         "litellm_params": {
+#             "model": "azure/gpt-turbo",
+#             "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+#             "api_base": "https://openai-france-1234.openai.azure.com",
+#             "rpm": 1440,
+#         }
+#     }, {
+#         "model_name": "azure-model",
+#         "litellm_params": {
+#             "model": "azure/gpt-35-turbo",
+#             "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+#             "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+#             "rpm": 6
+#         }
+#     }, {
+#         "model_name": "azure-model",
+#         "litellm_params": {
+#             "model": "azure/gpt-35-turbo",
+#             "api_key": "os.environ/AZURE_CANADA_API_KEY",
+#             "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+#             "rpm": 6
+#         }
+#     }]
+#     router = Router(model_list=model_list,
+#                     routing_strategy="least-busy",
+#                     set_verbose=False,
+#                     num_retries=3) # type: ignore
+
+#     async def call_azure_completion():
+#         try:
+#             response = await router.acompletion(
+#                 model="azure-model",
+#                 messages=[
+#                     {
+#                         "role": "user",
+#                         "content": "hello this request will pass"
+#                     }
+#                 ]
+#             )
+#             print("\n response", response)
+#             return response
+#         except:
+#             return None
+
+#     n = 1000
+#     start_time = time.time()
+#     tasks = [call_azure_completion() for _ in range(n)]
+#     chat_completions = await asyncio.gather(*tasks)
+#     successful_completions = [c for c in chat_completions if c is not None]
+#     print(n, time.time() - start_time, len(successful_completions))
+
+# asyncio.run(test_least_busy_routing())
\ No newline at end of file
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index e135f4228..6f43ed99f 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -226,10 +226,12 @@ def test_call_one_endpoint():
         )
         print("\n response", response)

+    tasks = []
+    tasks.append(call_azure_completion())
+    tasks.append(call_bedrock_claude())
+    tasks.append(call_azure_embedding())

-    asyncio.run(call_azure_completion())
-    asyncio.run(call_bedrock_claude())
-    asyncio.run(call_azure_embedding())
+    asyncio.gather(*tasks)

     os.environ["AZURE_API_BASE"] = old_api_base
     os.environ["AZURE_API_KEY"] = old_api_key
diff --git a/litellm/utils.py b/litellm/utils.py
index 07bbc744a..b8e4dd069 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -562,19 +562,25 @@ class Logging:
             **self.optional_params
         }

+    def _pre_call(self, input, api_key, model=None, additional_args={}):
+        """
+        Common helper function across the sync + async pre-call function
+        """
+        # print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
+        self.model_call_details["input"] = input
+        self.model_call_details["api_key"] = api_key
+        self.model_call_details["additional_args"] = additional_args
+        self.model_call_details["log_event_type"] = "pre_api_call"
+        if (
+            model
+        ): # if model name was changed pre-call, overwrite the initial model call name with the new one
+            self.model_call_details["model"] = model
+
     def pre_call(self, input, api_key, model=None, additional_args={}): # Log the exact input to the LLM API
         litellm.error_logs['PRE_CALL'] = locals()
         try:
-            # print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
-            self.model_call_details["input"] = input
-            self.model_call_details["api_key"] = api_key
-            self.model_call_details["additional_args"] = additional_args
-            self.model_call_details["log_event_type"] = "pre_api_call"
-            if (
-                model
-            ): # if model name was changes pre-call, overwrite the initial model call name with the new one
-                self.model_call_details["model"] = model
+            self._pre_call(input=input, api_key=api_key, model=model, additional_args=additional_args)

             # User Logging -> if you pass in a custom logging function
             headers = additional_args.get("headers", {})

@@ -688,6 +694,34 @@ class Logging:
             if capture_exception: # log this error to sentry for debugging
                 capture_exception(e)

+    async def async_pre_call(self, result=None, start_time=None, end_time=None, **kwargs):
+        """
+        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
+        """
+        start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
+        print_verbose(f"Async input callbacks: {litellm._async_input_callback}")
+        for callback in litellm._async_input_callback:
+            try:
+                if isinstance(callback, CustomLogger): # custom logger class
+                    print_verbose(f"Async input callbacks: CustomLogger")
+                    asyncio.create_task(callback.async_log_pre_api_call(
+                        model=self.model,
+                        messages=self.messages,
+                        kwargs=self.model_call_details,
+                    ))
+                if callable(callback): # custom logger functions
+                    print_verbose(f"Async input callbacks: async_log_input_event")
+                    asyncio.create_task(customLogger.async_log_input_event(
+                        model=self.model,
+                        messages=self.messages,
+                        kwargs=self.model_call_details,
+                        print_verbose=print_verbose,
+                        callback_func=callback
+                    ))
+            except:
+                print_verbose(
+                    f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while input logging {traceback.format_exc()}"
+                )

     def post_call(self, original_response, input=None, api_key=None, additional_args={}): # Log the exact result from the LLM API, for streaming - log the type of response received
         litellm.error_logs['POST_CALL'] = locals()
@@ -1289,6 +1323,17 @@ def client(original_function):
                     function_id=function_id
                 )
             ## ASYNC CALLBACKS
+            if len(litellm.input_callback) > 0:
+                removed_async_items = []
+                for index, callback in enumerate(litellm.input_callback):
+                    if inspect.iscoroutinefunction(callback):
+                        litellm._async_input_callback.append(callback)
+                        removed_async_items.append(index)
+
+                # Pop the async items from input_callback in reverse order to avoid index issues
+                for index in reversed(removed_async_items):
+                    litellm.input_callback.pop(index)
+
             if len(litellm.success_callback) > 0:
                 removed_async_items = []
                 for index, callback in enumerate(litellm.success_callback):
@@ -1307,7 +1352,7 @@ def client(original_function):
                         litellm._async_failure_callback.append(callback)
                         removed_async_items.append(index)

-                # Pop the async items from success_callback in reverse order to avoid index issues
+                # Pop the async items from failure_callback in reverse order to avoid index issues
                 for index in reversed(removed_async_items):
                     litellm.failure_callback.pop(index)
             if add_breadcrumb: