fix(router.py): fix least-busy routing

This commit is contained in:
Krrish Dholakia 2023-12-08 20:29:37 -08:00
parent d9b115b8fb
commit 4bf875d3ed
8 changed files with 292 additions and 31 deletions

View file

@ -562,19 +562,25 @@ class Logging:
**self.optional_params
}
def _pre_call(self, input, api_key, model=None, additional_args={}):
"""
Common helper function across the sync + async pre-call function
"""
# print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
self.model_call_details["input"] = input
self.model_call_details["api_key"] = api_key
self.model_call_details["additional_args"] = additional_args
self.model_call_details["log_event_type"] = "pre_api_call"
if (
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
def pre_call(self, input, api_key, model=None, additional_args={}):
# Log the exact input to the LLM API
litellm.error_logs['PRE_CALL'] = locals()
try:
# print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
self.model_call_details["input"] = input
self.model_call_details["api_key"] = api_key
self.model_call_details["additional_args"] = additional_args
self.model_call_details["log_event_type"] = "pre_api_call"
if (
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
self._pre_call(input=input, api_key=api_key, model=model, additional_args=additional_args)
# User Logging -> if you pass in a custom logging function
headers = additional_args.get("headers", {})
@ -688,6 +694,34 @@ class Logging:
if capture_exception: # log this error to sentry for debugging
capture_exception(e)
async def async_pre_call(self, result=None, start_time=None, end_time=None, **kwargs):
"""
 Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
print_verbose(f"Async input callbacks: {litellm._async_input_callback}")
for callback in litellm._async_input_callback:
try:
if isinstance(callback, CustomLogger): # custom logger class
print_verbose(f"Async input callbacks: CustomLogger")
asyncio.create_task(callback.async_log_input_event(
model=self.model,
messages=self.messages,
kwargs=self.model_call_details,
))
if callable(callback): # custom logger functions
print_verbose(f"Async success callbacks: async_log_event")
asyncio.create_task(customLogger.async_log_input_event(
model=self.model,
messages=self.messages,
kwargs=self.model_call_details,
print_verbose=print_verbose,
callback_func=callback
))
except:
print_verbose(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
)
def post_call(self, original_response, input=None, api_key=None, additional_args={}):
# Log the exact result from the LLM API, for streaming - log the type of response received
litellm.error_logs['POST_CALL'] = locals()
@ -1289,6 +1323,17 @@ def client(original_function):
function_id=function_id
)
## ASYNC CALLBACKS
if len(litellm.input_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.input_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_input_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from input_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.input_callback.pop(index)
if len(litellm.success_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.success_callback):
@ -1307,7 +1352,7 @@ def client(original_function):
litellm._async_failure_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
# Pop the async items from failure_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.failure_callback.pop(index)
if add_breadcrumb: