fix(custom_logger.py): enable pre_call hooks to modify incoming data to proxy

Krrish Dholakia 2023-12-13 16:20:13 -08:00
parent 03d6dcefbb
commit effdddc1c8
4 changed files with 51 additions and 43 deletions

@@ -4,7 +4,7 @@ import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
+from litellm.integrations.custom_logger import CustomLogger

 def print_verbose(print_statement):
     if litellm.set_verbose:
         print(print_statement) # noqa
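
The CustomLogger import added above is what the hooks below type-check against: a handler only participates if it is a CustomLogger instance sitting on litellm.callbacks. A minimal registration sketch (the handler class name here is a hypothetical example, not part of this commit):

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyProxyHandler(CustomLogger):
    """Hypothetical handler; subclassing CustomLogger is what makes the
    isinstance() checks in ProxyLogging match it."""
    pass

# ProxyLogging iterates litellm.callbacks, so the instance must be registered there.
litellm.callbacks = [MyProxyHandler()]
```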
@@ -65,16 +65,12 @@ class ProxyLogging:
         2. /embeddings
         """
         try:
             self.call_details["data"] = data
             self.call_details["call_type"] = call_type
-            ## check if max parallel requests set
-            if user_api_key_dict.max_parallel_requests is not None:
-                ## if set, check if request allowed
-                await self.max_parallel_request_limiter.max_parallel_request_allow_request(
-                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
+            for callback in litellm.callbacks:
+                if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__):
+                    response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type)
+                    if response is not None:
+                        data = response
             print_verbose(f'final data being sent to {call_type} call: {data}')
             return data
         except Exception as e:
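
The loop above is what the commit message refers to: any registered CustomLogger whose class defines async_pre_call_hook now gets a chance to rewrite the request body, and a non-None return value replaces `data` before the proxy forwards it. A sketch of such a hook, with the parameter names taken from the call site above and the rest (handler name, metadata key, type hints) assumed for illustration:

```python
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache

class AddMetadataHandler(CustomLogger):  # hypothetical example handler
    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth,
                                  cache: DualCache, data: dict, call_type: str):
        # Tag every /chat/completions or /embeddings request with the calling key;
        # returning a non-None value makes ProxyLogging treat it as the new `data`.
        data.setdefault("metadata", {})["proxy_api_key"] = user_api_key_dict.api_key
        return data
```

Note the `'async_pre_call_hook' in vars(callback.__class__)` guard: the hook only fires when the subclass defines the method itself, so callbacks that merely inherit the base-class default are skipped.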
@@ -103,17 +99,13 @@ class ProxyLogging:
         1. /chat/completions
         2. /embeddings
         """
-        # check if max parallel requests set
-        if user_api_key_dict is not None and user_api_key_dict.max_parallel_requests is not None:
-            ## decrement call count if call failed
-            if (hasattr(original_exception, "status_code")
-                and original_exception.status_code == 429
-                and "Max parallel request limit reached" in str(original_exception)):
-                pass # ignore failed calls due to max limit being reached
-            else:
-                await self.max_parallel_request_limiter.async_log_failure_call(
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
+        for callback in litellm.callbacks:
+            try:
+                if isinstance(callback, CustomLogger):
+                    await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception)
+            except Exception as e:
+                raise e
         return
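
The failure path follows the same pattern: instead of ProxyLogging decrementing the parallel-request counter itself, every registered CustomLogger has its async_post_call_failure_hook invoked with the failing key and exception. A sketch of a handler on that hook, with the parameter names taken from the call above and everything else assumed for illustration:

```python
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

class FailureAuditHandler(CustomLogger):  # hypothetical example handler
    async def async_post_call_failure_hook(self, user_api_key_dict: UserAPIKeyAuth,
                                           original_exception: Exception):
        # Runs after a proxied call fails; per-key failure accounting (the role
        # the parallel-request limiter now fills) can live here instead of in ProxyLogging.
        status = getattr(original_exception, "status_code", "n/a")
        print(f"call failed for key={user_api_key_dict.api_key} "
              f"(status={status}): {original_exception}")
```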