diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 80cfb03de4..98685e1a7c 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -1009,8 +1009,6 @@ async def get_key_object( ) return _response - except DB_CONNECTION_ERROR_TYPES as e: - return await _handle_failed_db_connection_for_get_key_object(e=e) except Exception: traceback.print_exc() raise Exception( @@ -1018,46 +1016,6 @@ async def get_key_object( ) -async def _handle_failed_db_connection_for_get_key_object( - e: Exception, -) -> UserAPIKeyAuth: - """ - Handles httpx.ConnectError when reading a Virtual Key from LiteLLM DB - - Use this if you don't want failed DB queries to block LLM API reqiests - - Returns: - - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True - - Raises: - - Orignal Exception in all other cases - """ - from litellm.proxy.proxy_server import ( - general_settings, - litellm_proxy_admin_name, - proxy_logging_obj, - ) - - # If this flag is on, requests failing to connect to the DB will be allowed - if general_settings.get("allow_requests_on_db_unavailable", False) is True: - # log this as a DB failure on prometheus - proxy_logging_obj.service_logging_obj.service_failure_hook( - service=ServiceTypes.DB, - call_type="get_key_object", - error=e, - duration=0.0, - ) - - return UserAPIKeyAuth( - key_name="failed-to-connect-to-db", - token="failed-to-connect-to-db", - user_id=litellm_proxy_admin_name, - ) - else: - # raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus - raise e - - @log_db_metrics async def get_org_object( org_id: str, diff --git a/litellm/proxy/auth/auth_exception_handler.py b/litellm/proxy/auth/auth_exception_handler.py new file mode 100644 index 0000000000..88a94fd5b9 --- /dev/null +++ b/litellm/proxy/auth/auth_exception_handler.py @@ -0,0 +1,130 @@ +""" +Handles Authentication Errors +""" + +import asyncio +from typing import TYPE_CHECKING, Any, Optional + +from fastapi import HTTPException, Request, status + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth +from litellm.proxy.auth.auth_utils import _get_request_ip_address +from litellm.types.services import ServiceTypes + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class UserAPIKeyAuthExceptionHandler: + + @staticmethod + async def _handle_authentication_error( + e: Exception, + request: Request, + request_data: dict, + route: str, + parent_otel_span: Optional[Span], + api_key: str, + ) -> UserAPIKeyAuth: + """ + Handles Connection Errors when reading a Virtual Key from LiteLLM DB + Use this if you don't want failed DB queries to block LLM API reqiests + + Reliability scenarios this covers: + - DB is down and having an outage + - Unable to read / recover a key from the DB + + Returns: + - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True + + Raises: + - Orignal Exception in all other cases + """ + from litellm.proxy.proxy_server import ( + general_settings, + litellm_proxy_admin_name, + proxy_logging_obj, + ) + + if UserAPIKeyAuthExceptionHandler.should_allow_request_on_db_unavailable(): + # log this as a DB failure on prometheus + proxy_logging_obj.service_logging_obj.service_failure_hook( + service=ServiceTypes.DB, + call_type="get_key_object", + error=e, + duration=0.0, + ) + + return UserAPIKeyAuth( + key_name="failed-to-connect-to-db", + token="failed-to-connect-to-db", + user_id=litellm_proxy_admin_name, + ) + else: + # raise the exception to the caller + requester_ip = _get_request_ip_address( + request=request, + use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False), + ) + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format( + str(e), + requester_ip, + ), + extra={"requester_ip": requester_ip}, + ) + + # Log this exception to OTEL, Datadog etc + user_api_key_dict = UserAPIKeyAuth( + parent_otel_span=parent_otel_span, + api_key=api_key, + ) + asyncio.create_task( + proxy_logging_obj.post_call_failure_hook( + request_data=request_data, + original_exception=e, + user_api_key_dict=user_api_key_dict, + error_type=ProxyErrorTypes.auth_error, + route=route, + ) + ) + + if isinstance(e, litellm.BudgetExceededError): + raise ProxyException( + message=e.message, + type=ProxyErrorTypes.budget_exceeded, + param=None, + code=400, + ) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=status.HTTP_401_UNAUTHORIZED, + ) + + @staticmethod + def should_allow_request_on_db_unavailable() -> bool: + """ + Returns True if the request should be allowed to proceed despite the DB connection error + """ + from litellm.proxy.proxy_server import general_settings + + if general_settings.get("allow_requests_on_db_unavailable", False) is True: + return True + return False diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index b78619ae65..a2850ca294 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -26,7 +26,6 @@ from litellm.proxy._types import * from litellm.proxy.auth.auth_checks import ( _cache_key_object, _get_user_role, - _handle_failed_db_connection_for_get_key_object, _is_user_proxy_admin, _virtual_key_max_budget_check, _virtual_key_soft_budget_check, @@ -38,6 +37,7 @@ from litellm.proxy.auth.auth_checks import ( get_user_object, is_valid_fallback_model, ) +from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler from litellm.proxy.auth.auth_utils import ( _get_request_ip_address, get_end_user_id_from_request_body, @@ -675,9 +675,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 if ( prisma_client is None ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error - return await _handle_failed_db_connection_for_get_key_object( - e=Exception("No connected db.") - ) + raise Exception("No connected db.") ## check for cache hit (In-Memory Cache) _user_role = None @@ -1018,55 +1016,14 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 else: raise Exception() except Exception as e: - requester_ip = _get_request_ip_address( + return await UserAPIKeyAuthExceptionHandler._handle_authentication_error( + e=e, request=request, - use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False), - ) - verbose_proxy_logger.exception( - "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format( - str(e), - requester_ip, - ), - extra={"requester_ip": requester_ip}, - ) - - # Log this exception to OTEL, Datadog etc - user_api_key_dict = UserAPIKeyAuth( + request_data=request_data, + route=route, parent_otel_span=parent_otel_span, api_key=api_key, ) - asyncio.create_task( - proxy_logging_obj.post_call_failure_hook( - request_data=request_data, - original_exception=e, - user_api_key_dict=user_api_key_dict, - error_type=ProxyErrorTypes.auth_error, - route=route, - ) - ) - - if isinstance(e, litellm.BudgetExceededError): - raise ProxyException( - message=e.message, - type=ProxyErrorTypes.budget_exceeded, - param=None, - code=400, - ) - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type=ProxyErrorTypes.auth_error, - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type=ProxyErrorTypes.auth_error, - param=getattr(e, "param", "None"), - code=status.HTTP_401_UNAUTHORIZED, - ) @tracer.wrap()