mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

fix(router.py): fix least-busy routing

This commit is contained in: parent d9b115b8fb, commit 4bf875d3ed

8 changed files with 292 additions and 31 deletions
@@ -9,6 +9,7 @@ input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 callbacks: List[Callable] = []
+_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
 _async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
 _async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
@@ -61,8 +61,6 @@ class InMemoryCache(BaseCache):
            try:
                cached_response = json.loads(original_cached_response)
            except:
                cached_response = original_cached_response
-           if isinstance(cached_response, dict):
-               cached_response['cache'] = True # set cache-hit flag to True
            return cached_response
        return None
@@ -110,8 +108,6 @@ class RedisCache(BaseCache):
                try:
                    cached_response = json.loads(cached_response) # Convert string to dictionary
                except:
                    cached_response = ast.literal_eval(cached_response)
-               if isinstance(cached_response, dict):
-                   cached_response['cache'] = True # set cache-hit flag to True
                return cached_response
        except Exception as e:
            # NON blocking - notify users Redis is throwing an exception
@@ -28,6 +28,9 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        pass

+   async def async_log_pre_api_call(self, model, messages, kwargs):
+       pass
+
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        pass

@@ -51,6 +54,22 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
            traceback.print_exc()
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

+   async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
+       try:
+           kwargs["model"] = model
+           kwargs["messages"] = messages
+           kwargs["log_event_type"] = "pre_api_call"
+           await callback_func(
+               kwargs,
+           )
+           print_verbose(
+               f"Custom Logger - model call details: {kwargs}"
+           )
+       except:
+           traceback.print_exc()
+           print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
+
    def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
        # Method definition
        try:
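Note: the hunk above adds async hooks to CustomLogger. As a rough illustration (not part of this commit; the subclass below is hypothetical), a custom logger could implement the new async pre-call hook like this:

    # Hypothetical subclass (not from this commit) using the new async hooks.
    from litellm.integrations.custom_logger import CustomLogger

    class MyAsyncLogger(CustomLogger):
        async def async_log_pre_api_call(self, model, messages, kwargs):
            # fires just before the LLM API request is dispatched
            print(f"pre-call: model={model}, num_messages={len(messages)}")

        async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
            print("call succeeded:", kwargs.get("model"))

    # registration would mirror what the router does below:
    # litellm.callbacks = [MyAsyncLogger()]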
@@ -9,14 +9,14 @@

 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal, Any
-import random, threading, time, traceback
+import random, threading, time, traceback, uuid
 import litellm, openai
 from litellm.caching import RedisCache, InMemoryCache, DualCache
 import logging, asyncio
 import inspect, concurrent
 from openai import AsyncOpenAI
 from collections import defaultdict
+from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 class Router:
     """
     Example usage:
@@ -57,6 +57,7 @@ class Router:
    default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
    num_retries: int = 0
    tenacity = None
+   leastbusy_logger: Optional[LeastBusyLoggingHandler] = None

    def __init__(self,
                 model_list: Optional[list] = None,
@@ -107,10 +108,6 @@ class Router:
        self.default_litellm_params.setdefault("timeout", timeout)
        self.default_litellm_params.setdefault("max_retries", 0)

-       ### HEALTH CHECK THREAD ###
-       if self.routing_strategy == "least-busy":
-           self._start_health_check_thread()
-
        ### CACHING ###
        cache_type = "local" # default to an in-memory cache
        redis_cache = None
@@ -137,6 +134,16 @@ class Router:
            litellm.cache = litellm.Cache(type=cache_type, **cache_config)
            self.cache_responses = cache_responses
        self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
+       ### ROUTING SETUP ###
+       if routing_strategy == "least-busy":
+           self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
+           ## add callback
+           if isinstance(litellm.input_callback, list):
+               litellm.input_callback.append(self.leastbusy_logger) # type: ignore
+           else:
+               litellm.input_callback = [self.leastbusy_logger] # type: ignore
+           if isinstance(litellm.callbacks, list):
+               litellm.callbacks.append(self.leastbusy_logger) # type: ignore
        ## USAGE TRACKING ##
        if isinstance(litellm.success_callback, list):
            litellm.success_callback.append(self.deployment_callback)
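For reference, the commented-out test added later in this commit constructs a router with this strategy roughly as follows (condensed sketch; the keys and endpoints are the placeholders used in that test):

    # Condensed sketch of wiring up the least-busy strategy (adapted from the
    # commented-out test in this commit; api_key/api_base values are placeholders).
    from litellm import Router

    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
            },
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
            },
        },
    ]

    # routing_strategy="least-busy" registers a LeastBusyLoggingHandler on
    # litellm.input_callback / litellm.callbacks, as shown in the hunk above.
    router = Router(model_list=model_list, routing_strategy="least-busy")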
@@ -664,6 +671,7 @@ class Router:
            return kwargs
        except Exception as e:
            raise e

    def _set_cooldown_deployments(self,
                                  deployment: str):
        """
@@ -873,6 +881,10 @@ class Router:
        for model in self.model_list:
            litellm_params = model.get("litellm_params", {})
            model_name = litellm_params.get("model")
+           #### MODEL ID INIT ########
+           model_info = model.get("model_info", {})
+           model_info["id"] = model_info.get("id", str(uuid.uuid4()))
+           model["model_info"] = model_info
            #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
            custom_llm_provider = litellm_params.get("custom_llm_provider")
            if custom_llm_provider is None:
@@ -1119,8 +1131,8 @@ class Router:
        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
        if len(healthy_deployments) == 0:
            # check if the user sent in a deployment name instead
            healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]

        self.print_verbose(f"initial list of deployments: {healthy_deployments}")
        deployments_to_remove = []
        cooldown_deployments = self._get_cooldown_deployments()
@@ -1140,13 +1152,24 @@ class Router:
            model = litellm.model_alias_map[
                model
            ] # update the model to the actual value if an alias has been passed in
-       if self.routing_strategy == "least-busy":
-           if len(self.healthy_deployments) > 0:
-               for item in self.healthy_deployments:
-                   if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
-                       return item[0]
+       if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
+           deployments = self.leastbusy_logger.get_available_deployments(model_group=model)
+           # pick least busy deployment
+           min_traffic = float('inf')
+           min_deployment = None
+           for k, v in deployments.items():
+               if v < min_traffic:
+                   min_deployment = k
+           ############## No Available Deployments passed, we do a random pick #################
+           if min_deployment is None:
+               min_deployment = random.choice(healthy_deployments)
+           ############## Available Deployments passed, we find the relevant item #################
            else:
-               raise ValueError("No models available.")
+               for m in healthy_deployments:
+                   if m["model_info"]["id"] == min_deployment:
+                       return m
+               min_deployment = random.choice(healthy_deployments)
+           return min_deployment
        elif self.routing_strategy == "simple-shuffle":
            # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
            ############## Check if we can do a RPM/TPM based weighted pick #################
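Conceptually, get_available_deployments returns a map of deployment id to in-flight request count, and the router returns the deployment with the smallest count, falling back to a random healthy deployment when no counts exist yet. A standalone sketch of that selection idea (not the commit's code; note it also carries the running minimum while scanning):

    import random

    def pick_least_busy(deployments: dict, healthy_deployments: list):
        """Sketch: choose the deployment with the fewest in-flight requests."""
        min_traffic = float("inf")
        min_deployment = None
        for deployment_id, in_flight in deployments.items():
            if in_flight < min_traffic:
                min_traffic = in_flight          # keep the running minimum
                min_deployment = deployment_id
        if min_deployment is None:               # no traffic data yet -> random pick
            return random.choice(healthy_deployments)
        for m in healthy_deployments:
            if m["model_info"]["id"] == min_deployment:
                return m
        return random.choice(healthy_deployments)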
litellm/router_strategy/least_busy.py (new file, 96 lines)
@@ -0,0 +1,96 @@
+#### What this does ####
+# identifies least busy deployment
+# How is this achieved?
+# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
+# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
+# - use litellm.success + failure callbacks to log when a request completed
+# - in get_available_deployment, for a given model group name -> pick based on traffic
+
+import dotenv, os, requests
+from typing import Optional
+dotenv.load_dotenv() # Loading env variables using dotenv
+import traceback
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+
+class LeastBusyLoggingHandler(CustomLogger):
+
+    def __init__(self, router_cache: DualCache):
+        self.router_cache = router_cache
+        self.mapping_deployment_to_id: dict = {}
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        """
+        Log when a model is being used.
+
+        Caching based on model group.
+        """
+        try:
+            if kwargs['litellm_params'].get('metadata') is None:
+                pass
+            else:
+                deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
+                model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
+                id = kwargs['litellm_params'].get('model_info', {}).get('id', None)
+                if deployment is None or model_group is None or id is None:
+                    return
+
+                # map deployment to id
+                self.mapping_deployment_to_id[deployment] = id
+
+                request_count_api_key = f"{model_group}_request_count"
+                # update cache
+                request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
+                request_count_dict[deployment] = request_count_dict.get(deployment, 0) + 1
+                self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
+        except Exception as e:
+            pass
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs['litellm_params'].get('metadata') is None:
+                pass
+            else:
+                deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
+                model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
+                if deployment is None or model_group is None:
+                    return
+
+                request_count_api_key = f"{model_group}_request_count"
+                # decrement count in cache
+                request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
+                request_count_dict[deployment] = request_count_dict.get(deployment)
+                self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
+        except Exception as e:
+            pass
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            if kwargs['litellm_params'].get('metadata') is None:
+                pass
+            else:
+                deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
+                model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
+                if deployment is None or model_group is None:
+                    return
+
+                request_count_api_key = f"{model_group}_request_count"
+                # decrement count in cache
+                request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
+                request_count_dict[deployment] = request_count_dict.get(deployment)
+                self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
+        except Exception as e:
+            pass
+
+    def get_available_deployments(self, model_group: str):
+        request_count_api_key = f"{model_group}_request_count"
+        request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
+        # map deployment to id
+        return_dict = {}
+        for key, value in request_count_dict.items():
+            return_dict[self.mapping_deployment_to_id[key]] = value
+        return return_dict
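To make the bookkeeping above concrete, here is a hypothetical walk-through of the cache entry for one model group (illustrative only, not part of the new file):

    # Hypothetical walk-through of the entry keyed by "{model_group}_request_count"
    # for model_group = "azure-model" (names are placeholders).
    request_count_dict = {}                      # starts empty

    # log_pre_api_call fires for deployment "azure/gpt-turbo": in-flight count goes up
    request_count_dict["azure/gpt-turbo"] = request_count_dict.get("azure/gpt-turbo", 0) + 1
    # -> {"azure/gpt-turbo": 1}

    # another request hits "azure/gpt-35-turbo"
    request_count_dict["azure/gpt-35-turbo"] = request_count_dict.get("azure/gpt-35-turbo", 0) + 1
    # -> {"azure/gpt-turbo": 1, "azure/gpt-35-turbo": 1}

    # get_available_deployments("azure-model") then translates deployment names to
    # model_info ids, and the router picks the entry with the lowest count.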
litellm/tests/test_least_busy_routing.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+# #### What this tests ####
+# # This tests the router's ability to identify the least busy deployment
+
+# #
+# # How is this achieved?
+# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
+# # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
+# # - use litellm.success + failure callbacks to log when a request completed
+# # - in get_available_deployment, for a given model group name -> pick based on traffic
+
+# import sys, os, asyncio, time
+# import traceback
+# from dotenv import load_dotenv
+
+# load_dotenv()
+# import os
+
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# ) # Adds the parent directory to the system path
+# import pytest
+# from litellm import Router
+# import litellm
+
+# async def test_least_busy_routing():
+#     model_list = [{
+#         "model_name": "azure-model",
+#         "litellm_params": {
+#             "model": "azure/gpt-turbo",
+#             "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+#             "api_base": "https://openai-france-1234.openai.azure.com",
+#             "rpm": 1440,
+#         }
+#     }, {
+#         "model_name": "azure-model",
+#         "litellm_params": {
+#             "model": "azure/gpt-35-turbo",
+#             "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+#             "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+#             "rpm": 6
+#         }
+#     }, {
+#         "model_name": "azure-model",
+#         "litellm_params": {
+#             "model": "azure/gpt-35-turbo",
+#             "api_key": "os.environ/AZURE_CANADA_API_KEY",
+#             "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+#             "rpm": 6
+#         }
+#     }]
+#     router = Router(model_list=model_list,
+#                     routing_strategy="least-busy",
+#                     set_verbose=False,
+#                     num_retries=3) # type: ignore
+
+#     async def call_azure_completion():
+#         try:
+#             response = await router.acompletion(
+#                 model="azure-model",
+#                 messages=[
+#                     {
+#                         "role": "user",
+#                         "content": "hello this request will pass"
+#                     }
+#                 ]
+#             )
+#             print("\n response", response)
+#             return response
+#         except:
+#             return None
+
+#     n = 1000
+#     start_time = time.time()
+#     tasks = [call_azure_completion() for _ in range(n)]
+#     chat_completions = await asyncio.gather(*tasks)
+#     successful_completions = [c for c in chat_completions if c is not None]
+#     print(n, time.time() - start_time, len(successful_completions))

+# asyncio.run(test_least_busy_routing())
@@ -226,10 +226,12 @@ def test_call_one_endpoint():
        )

        print("\n response", response)
-   asyncio.run(call_azure_completion())
-   asyncio.run(call_bedrock_claude())
-   asyncio.run(call_azure_embedding())
+   tasks = []
+   tasks.append(call_azure_completion())
+   tasks.append(call_bedrock_claude())
+   tasks.append(call_azure_embedding())
+
+   asyncio.gather(**tasks)

    os.environ["AZURE_API_BASE"] = old_api_base
    os.environ["AZURE_API_KEY"] = old_api_key
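As an aside, asyncio.gather expects the coroutines positionally and the gathered future must itself be awaited (or driven by asyncio.run); a minimal standalone sketch, independent of the test above:

    import asyncio

    async def run_all(tasks):
        # pass the coroutines positionally and await the combined future
        return await asyncio.gather(*tasks)

    # results = asyncio.run(run_all([some_coro(), another_coro()]))  # hypothetical coroutines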
@@ -562,10 +562,10 @@ class Logging:
            **self.optional_params
        }

-   def pre_call(self, input, api_key, model=None, additional_args={}):
-       # Log the exact input to the LLM API
-       litellm.error_logs['PRE_CALL'] = locals()
-       try:
+   def _pre_call(self, input, api_key, model=None, additional_args={}):
+       """
+       Common helper function across the sync + async pre-call function
+       """
        # print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
        self.model_call_details["input"] = input
        self.model_call_details["api_key"] = api_key
@@ -576,6 +576,12 @@ class Logging:
        ): # if model name was changes pre-call, overwrite the initial model call name with the new one
            self.model_call_details["model"] = model

+   def pre_call(self, input, api_key, model=None, additional_args={}):
+       # Log the exact input to the LLM API
+       litellm.error_logs['PRE_CALL'] = locals()
+       try:
+           self._pre_call(input=input, api_key=api_key, model=model, additional_args=additional_args)
+
            # User Logging -> if you pass in a custom logging function
            headers = additional_args.get("headers", {})
            if headers is None:
@@ -688,6 +694,34 @@ class Logging:
            if capture_exception: # log this error to sentry for debugging
                capture_exception(e)

+   async def async_pre_call(self, result=None, start_time=None, end_time=None, **kwargs):
+       """
+       Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
+       """
+       start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
+       print_verbose(f"Async input callbacks: {litellm._async_input_callback}")
+       for callback in litellm._async_input_callback:
+           try:
+               if isinstance(callback, CustomLogger): # custom logger class
+                   print_verbose(f"Async input callbacks: CustomLogger")
+                   asyncio.create_task(callback.async_log_input_event(
+                       model=self.model,
+                       messages=self.messages,
+                       kwargs=self.model_call_details,
+                   ))
+               if callable(callback): # custom logger functions
+                   print_verbose(f"Async success callbacks: async_log_event")
+                   asyncio.create_task(customLogger.async_log_input_event(
+                       model=self.model,
+                       messages=self.messages,
+                       kwargs=self.model_call_details,
+                       print_verbose=print_verbose,
+                       callback_func=callback
+                   ))
+           except:
+               print_verbose(
+                   f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
+               )
    def post_call(self, original_response, input=None, api_key=None, additional_args={}):
        # Log the exact result from the LLM API, for streaming - log the type of response received
        litellm.error_logs['POST_CALL'] = locals()
@@ -1289,6 +1323,17 @@ def client(original_function):
                function_id=function_id
            )
        ## ASYNC CALLBACKS
+       if len(litellm.input_callback) > 0:
+           removed_async_items = []
+           for index, callback in enumerate(litellm.input_callback):
+               if inspect.iscoroutinefunction(callback):
+                   litellm._async_input_callback.append(callback)
+                   removed_async_items.append(index)
+
+           # Pop the async items from input_callback in reverse order to avoid index issues
+           for index in reversed(removed_async_items):
+               litellm.input_callback.pop(index)
+
        if len(litellm.success_callback) > 0:
            removed_async_items = []
            for index, callback in enumerate(litellm.success_callback):
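For illustration, a coroutine function registered as an input callback is what the block above reroutes into litellm._async_input_callback (the callback below is hypothetical, not part of the commit):

    # Hypothetical async input callback. Because it is a coroutine function,
    # the block above moves it from litellm.input_callback into
    # litellm._async_input_callback, where async_pre_call dispatches it with
    # the call's kwargs (see async_log_input_event above).
    import litellm

    async def log_request_start(kwargs):
        print("about to call model:", kwargs.get("model"))

    litellm.input_callback = [log_request_start]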
@@ -1307,7 +1352,7 @@ def client(original_function):
                    litellm._async_failure_callback.append(callback)
                    removed_async_items.append(index)

-           # Pop the async items from success_callback in reverse order to avoid index issues
+           # Pop the async items from failure_callback in reverse order to avoid index issues
            for index in reversed(removed_async_items):
                litellm.failure_callback.pop(index)
        if add_breadcrumb: