diff --git a/litellm/proxy/hooks/failure_handler.py b/litellm/proxy/hooks/failure_handler.py new file mode 100644 index 000000000..36e0fb0e6 --- /dev/null +++ b/litellm/proxy/hooks/failure_handler.py @@ -0,0 +1,81 @@ +""" +Runs when LLM Exceptions occur on LiteLLM Proxy +""" + +import copy +import json +import uuid + +import litellm +from litellm.proxy._types import LiteLLM_ErrorLogs + + +async def _PROXY_failure_handler( + kwargs, # kwargs to completion + completion_response: litellm.ModelResponse, # response from completion + start_time=None, + end_time=None, # start/end time for completion +): + """ + Async Failure Handler - runs when LLM Exceptions occur on LiteLLM Proxy. + + This function logs the errors to the Prisma DB + """ + from litellm._logging import verbose_proxy_logger + from litellm.proxy.proxy_server import general_settings, prisma_client + + if general_settings.get("disable_error_logs") is True: + return + + if prisma_client is not None: + verbose_proxy_logger.debug( + "inside _PROXY_failure_handler kwargs=", extra=kwargs + ) + + _exception = kwargs.get("exception") + _exception_type = _exception.__class__.__name__ + _model = kwargs.get("model", None) + + _optional_params = kwargs.get("optional_params", {}) + _optional_params = copy.deepcopy(_optional_params) + + for k, v in _optional_params.items(): + v = str(v) + v = v[:100] + + _status_code = "500" + try: + _status_code = str(_exception.status_code) + except Exception: + # Don't let this fail logging the exception to the dB + pass + + _litellm_params = kwargs.get("litellm_params", {}) or {} + _metadata = _litellm_params.get("metadata", {}) or {} + _model_id = _metadata.get("model_info", {}).get("id", "") + _model_group = _metadata.get("model_group", "") + api_base = litellm.get_api_base(model=_model, optional_params=_litellm_params) + _exception_string = str(_exception) + + error_log = LiteLLM_ErrorLogs( + request_id=str(uuid.uuid4()), + model_group=_model_group, + model_id=_model_id, + litellm_model_name=kwargs.get("model"), + request_kwargs=_optional_params, + api_base=api_base, + exception_type=_exception_type, + status_code=_status_code, + exception_string=_exception_string, + startTime=kwargs.get("start_time"), + endTime=kwargs.get("end_time"), + ) + + error_log_dict = error_log.model_dump() + error_log_dict["request_kwargs"] = json.dumps(error_log_dict["request_kwargs"]) + + await prisma_client.db.litellm_errorlogs.create( + data=error_log_dict # type: ignore + ) + + pass diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 15971263a..011ed04de 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -170,6 +170,7 @@ from litellm.proxy.guardrails.init_guardrails import ( ) from litellm.proxy.health_check import perform_health_check from litellm.proxy.health_endpoints._health_endpoints import router as health_router +from litellm.proxy.hooks.failure_handler import _PROXY_failure_handler from litellm.proxy.hooks.prompt_injection_detection import ( _OPTIONAL_PromptInjectionDetection, ) @@ -526,14 +527,6 @@ db_writer_client: Optional[HTTPHandler] = None ### logger ### -def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict: - try: - return pydantic_obj.model_dump() # type: ignore - except Exception: - # if using pydantic v1 - return pydantic_obj.dict() - - def get_custom_headers( *, user_api_key_dict: UserAPIKeyAuth, @@ -687,68 +680,6 @@ def cost_tracking(): litellm._async_success_callback.append(_PROXY_track_cost_callback) # type: ignore -async def _PROXY_failure_handler( - kwargs, # kwargs to completion - completion_response: litellm.ModelResponse, # response from completion - start_time=None, - end_time=None, # start/end time for completion -): - global prisma_client - if prisma_client is not None: - verbose_proxy_logger.debug( - "inside _PROXY_failure_handler kwargs=", extra=kwargs - ) - - _exception = kwargs.get("exception") - _exception_type = _exception.__class__.__name__ - _model = kwargs.get("model", None) - - _optional_params = kwargs.get("optional_params", {}) - _optional_params = copy.deepcopy(_optional_params) - - for k, v in _optional_params.items(): - v = str(v) - v = v[:100] - - _status_code = "500" - try: - _status_code = str(_exception.status_code) - except Exception: - # Don't let this fail logging the exception to the dB - pass - - _litellm_params = kwargs.get("litellm_params", {}) or {} - _metadata = _litellm_params.get("metadata", {}) or {} - _model_id = _metadata.get("model_info", {}).get("id", "") - _model_group = _metadata.get("model_group", "") - api_base = litellm.get_api_base(model=_model, optional_params=_litellm_params) - _exception_string = str(_exception) - - error_log = LiteLLM_ErrorLogs( - request_id=str(uuid.uuid4()), - model_group=_model_group, - model_id=_model_id, - litellm_model_name=kwargs.get("model"), - request_kwargs=_optional_params, - api_base=api_base, - exception_type=_exception_type, - status_code=_status_code, - exception_string=_exception_string, - startTime=kwargs.get("start_time"), - endTime=kwargs.get("end_time"), - ) - - # helper function to convert to dict on pydantic v2 & v1 - error_log_dict = _get_pydantic_json_dict(error_log) - error_log_dict["request_kwargs"] = json.dumps(error_log_dict["request_kwargs"]) - - await prisma_client.db.litellm_errorlogs.create( - data=error_log_dict # type: ignore - ) - - pass - - @log_db_metrics async def _PROXY_track_cost_callback( kwargs, # kwargs to completion