diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md index 4f88fdb39b..ed65e14b8b 100644 --- a/docs/my-website/docs/providers/bedrock.md +++ b/docs/my-website/docs/providers/bedrock.md @@ -1776,6 +1776,7 @@ response = completion( ) ``` + 1. Setup config.yaml @@ -1820,11 +1821,13 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \ ``` + ### SSO Login (AWS Profile) - Set `AWS_PROFILE` environment variable - Make bedrock completion call + ```python import os from litellm import completion @@ -1940,9 +1943,6 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/images/generations' \ "colorGuidedGenerationParams":{"colors":["#FFFFFF"]} }' ``` - - - | Model Name | Function Call | |-------------------------|---------------------------------------------| diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 0093464d93..4a62184df7 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -160,7 +160,7 @@ general_settings: | database_url | string | The URL for the database connection [Set up Virtual Keys](virtual_keys) | | database_connection_pool_limit | integer | The limit for database connection pool [Setting DB Connection Pool limit](#configure-db-pool-limits--connection-timeouts) | | database_connection_timeout | integer | The timeout for database connections in seconds [Setting DB Connection Pool limit, timeout](#configure-db-pool-limits--connection-timeouts) | -| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even if DB is unreachable. **Only use this if running LiteLLM in your VPC** This will allow requests to work even when LiteLLM cannot connect to the DB to verify a Virtual Key | +| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even if DB is unreachable. 
**Only use this if running LiteLLM in your VPC** This will allow requests to work even when LiteLLM cannot connect to the DB to verify a Virtual Key [Doc on graceful db unavailability](prod#5-if-running-litellm-on-vpc-gracefully-handle-db-unavailability) | | custom_auth | string | Write your own custom authentication logic [Doc Custom Auth](virtual_keys#custom-auth) | | max_parallel_requests | integer | The max parallel requests allowed per deployment | | global_max_parallel_requests | integer | The max parallel requests allowed on the proxy overall | diff --git a/docs/my-website/docs/proxy/prod.md b/docs/my-website/docs/proxy/prod.md index d3ba2d6224..314300f2a0 100644 --- a/docs/my-website/docs/proxy/prod.md +++ b/docs/my-website/docs/proxy/prod.md @@ -94,15 +94,29 @@ This disables the load_dotenv() functionality, which will automatically load you ## 5. If running LiteLLM on VPC, gracefully handle DB unavailability -This will allow LiteLLM to continue to process requests even if the DB is unavailable. This is better handling for DB unavailability. +When running LiteLLM on a VPC (and inaccessible from the public internet), you can enable graceful degradation so that request processing continues even if the database is temporarily unavailable. + **WARNING: Only do this if you're running LiteLLM on VPC, that cannot be accessed from the public internet.** -```yaml +#### Configuration + +```yaml showLineNumbers title="litellm config.yaml" general_settings: allow_requests_on_db_unavailable: True ``` +#### Expected Behavior + +When `allow_requests_on_db_unavailable` is set to `true`, LiteLLM will handle errors as follows: + +| Type of Error | Expected Behavior | Details | +|---------------|-------------------|----------------| +| Prisma Errors | ✅ Request will be allowed | Covers issues like DB connection resets or rejections from the DB via Prisma, the ORM used by LiteLLM. 
| +| Httpx Errors | ✅ Request will be allowed | Occurs when the database is unreachable, allowing the request to proceed despite the DB outage. | +| LiteLLM Budget Errors or Model Errors | ❌ Request will be blocked | Triggered when the DB is reachable but the authentication token is invalid, lacks access, or exceeds budget limits. | + + ## 6. Disable spend_logs & error_logs if not using the LiteLLM UI By default, LiteLLM writes several types of logs to the database: @@ -182,94 +196,4 @@ You should only see the following level of details in logs on the proxy server # INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK # INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK # INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK -``` - - -### Machine Specifications to Deploy LiteLLM - -| Service | Spec | CPUs | Memory | Architecture | Version| -| --- | --- | --- | --- | --- | --- | -| Server | `t2.small`. | `1vCPUs` | `8GB` | `x86` | -| Redis Cache | - | - | - | - | 7.0+ Redis Engine| - - -### Reference Kubernetes Deployment YAML - -Reference Kubernetes `deployment.yaml` that was load tested by us - -```yaml -apiVersion: apps/v1 -kind: Deployment -metadata: - name: litellm-deployment -spec: - replicas: 3 - selector: - matchLabels: - app: litellm - template: - metadata: - labels: - app: litellm - spec: - containers: - - name: litellm-container - image: ghcr.io/berriai/litellm:main-latest - imagePullPolicy: Always - env: - - name: AZURE_API_KEY - value: "d6******" - - name: AZURE_API_BASE - value: "https://ope******" - - name: LITELLM_MASTER_KEY - value: "sk-1234" - - name: DATABASE_URL - value: "po**********" - args: - - "--config" - - "/app/proxy_config.yaml" # Update the path to mount the config file - volumeMounts: # Define volume mount for proxy_config.yaml - - name: config-volume - mountPath: /app - readOnly: true - livenessProbe: - httpGet: - path: /health/liveliness - port: 4000 - initialDelaySeconds: 120 - 
periodSeconds: 15 - successThreshold: 1 - failureThreshold: 3 - timeoutSeconds: 10 - readinessProbe: - httpGet: - path: /health/readiness - port: 4000 - initialDelaySeconds: 120 - periodSeconds: 15 - successThreshold: 1 - failureThreshold: 3 - timeoutSeconds: 10 - volumes: # Define volume to mount proxy_config.yaml - - name: config-volume - configMap: - name: litellm-config - -``` - - -Reference Kubernetes `service.yaml` that was load tested by us -```yaml -apiVersion: v1 -kind: Service -metadata: - name: litellm-service -spec: - selector: - app: litellm - ports: - - protocol: TCP - port: 4000 - targetPort: 4000 - type: LoadBalancer -``` +``` \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index ebe7d1e955..e6294baab8 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2067,16 +2067,68 @@ class SpendCalculateRequest(LiteLLMPydanticObjectBase): class ProxyErrorTypes(str, enum.Enum): budget_exceeded = "budget_exceeded" + """ + Object was over budget + """ + no_db_connection = "no_db_connection" + """ + No database connection + """ + + token_not_found_in_db = "token_not_found_in_db" + """ + Requested token was not found in the database + """ + key_model_access_denied = "key_model_access_denied" + """ + Key does not have access to the model + """ + team_model_access_denied = "team_model_access_denied" + """ + Team does not have access to the model + """ + user_model_access_denied = "user_model_access_denied" + """ + User does not have access to the model + """ + expired_key = "expired_key" + """ + Key has expired + """ + auth_error = "auth_error" + """ + General authentication error + """ + internal_server_error = "internal_server_error" + """ + Internal server error + """ + bad_request_error = "bad_request_error" + """ + Bad request error + """ + not_found_error = "not_found_error" - validation_error = "bad_request_error" + """ + Not found error + """ + + validation_error = "validation_error" + """ + 
Validation error + """ + cache_ping_error = "cache_ping_error" + """ + Cache ping error + """ @classmethod def get_model_access_error_type_for_object( @@ -2093,7 +2145,11 @@ class ProxyErrorTypes(str, enum.Enum): return cls.user_model_access_denied -DB_CONNECTION_ERROR_TYPES = (httpx.ConnectError, httpx.ReadError, httpx.ReadTimeout) +DB_CONNECTION_ERROR_TYPES = ( + httpx.ConnectError, + httpx.ReadError, + httpx.ReadTimeout, +) class SSOUserDefinedValues(TypedDict): diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 80cfb03de4..efbfe8d90c 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -11,7 +11,6 @@ Run checks for: import asyncio import re import time -import traceback from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast from fastapi import Request, status @@ -23,7 +22,6 @@ from litellm.caching.caching import DualCache from litellm.caching.dual_cache import LimitedSizeOrderedDict from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider from litellm.proxy._types import ( - DB_CONNECTION_ERROR_TYPES, RBAC_ROLES, CallInfo, LiteLLM_EndUserTable, @@ -45,7 +43,6 @@ from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.route_llm_request import route_request from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics from litellm.router import Router -from litellm.types.services import ServiceTypes from .auth_checks_organization import organization_role_based_access_check @@ -987,75 +984,34 @@ async def get_key_object( ) # else, check db - try: - _valid_token: Optional[BaseModel] = await prisma_client.get_data( - token=hashed_token, - table_name="combined_view", - parent_otel_span=parent_otel_span, - proxy_logging_obj=proxy_logging_obj, - ) - - if _valid_token is None: - raise Exception - - _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True)) - - # save the key object to cache - await 
_cache_key_object( - hashed_token=hashed_token, - user_api_key_obj=_response, - user_api_key_cache=user_api_key_cache, - proxy_logging_obj=proxy_logging_obj, - ) - - return _response - except DB_CONNECTION_ERROR_TYPES as e: - return await _handle_failed_db_connection_for_get_key_object(e=e) - except Exception: - traceback.print_exc() - raise Exception( - f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call." - ) - - -async def _handle_failed_db_connection_for_get_key_object( - e: Exception, -) -> UserAPIKeyAuth: - """ - Handles httpx.ConnectError when reading a Virtual Key from LiteLLM DB - - Use this if you don't want failed DB queries to block LLM API reqiests - - Returns: - - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True - - Raises: - - Orignal Exception in all other cases - """ - from litellm.proxy.proxy_server import ( - general_settings, - litellm_proxy_admin_name, - proxy_logging_obj, + _valid_token: Optional[BaseModel] = await prisma_client.get_data( + token=hashed_token, + table_name="combined_view", + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) - # If this flag is on, requests failing to connect to the DB will be allowed - if general_settings.get("allow_requests_on_db_unavailable", False) is True: - # log this as a DB failure on prometheus - proxy_logging_obj.service_logging_obj.service_failure_hook( - service=ServiceTypes.DB, - call_type="get_key_object", - error=e, - duration=0.0, + if _valid_token is None: + raise ProxyException( + message="Authentication Error, Invalid proxy server token passed. key={}, not found in db. 
Create key via `/key/generate` call.".format( + hashed_token + ), + type=ProxyErrorTypes.token_not_found_in_db, + param="key", + code=status.HTTP_401_UNAUTHORIZED, ) - return UserAPIKeyAuth( - key_name="failed-to-connect-to-db", - token="failed-to-connect-to-db", - user_id=litellm_proxy_admin_name, - ) - else: - # raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus - raise e + _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True)) + + # save the key object to cache + await _cache_key_object( + hashed_token=hashed_token, + user_api_key_obj=_response, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + return _response @log_db_metrics diff --git a/litellm/proxy/auth/auth_exception_handler.py b/litellm/proxy/auth/auth_exception_handler.py new file mode 100644 index 0000000000..c1a546b569 --- /dev/null +++ b/litellm/proxy/auth/auth_exception_handler.py @@ -0,0 +1,153 @@ +""" +Handles Authentication Errors +""" + +import asyncio +from typing import TYPE_CHECKING, Any, Optional + +from fastapi import HTTPException, Request, status + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + DB_CONNECTION_ERROR_TYPES, + ProxyErrorTypes, + ProxyException, + UserAPIKeyAuth, +) +from litellm.proxy.auth.auth_utils import _get_request_ip_address +from litellm.types.services import ServiceTypes + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class UserAPIKeyAuthExceptionHandler: + + @staticmethod + async def _handle_authentication_error( + e: Exception, + request: Request, + request_data: dict, + route: str, + parent_otel_span: Optional[Span], + api_key: str, + ) -> UserAPIKeyAuth: + """ + Handles Connection Errors when reading a Virtual Key from LiteLLM DB + Use this if you don't want failed DB queries to block LLM API requests + + Reliability scenarios this covers: 
+ - DB is down and having an outage + - Unable to read / recover a key from the DB + + Returns: + - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True + + Raises: + - ProxyException: In all other cases (the original exception is wrapped unless it is already a ProxyException) + """ + from litellm.proxy.proxy_server import ( + general_settings, + litellm_proxy_admin_name, + proxy_logging_obj, + ) + + if ( + UserAPIKeyAuthExceptionHandler.should_allow_request_on_db_unavailable() + and UserAPIKeyAuthExceptionHandler.is_database_connection_error(e) + ): + # log this as a DB failure on prometheus + proxy_logging_obj.service_logging_obj.service_failure_hook( + service=ServiceTypes.DB, + call_type="get_key_object", + error=e, + duration=0.0, + ) + + return UserAPIKeyAuth( + key_name="failed-to-connect-to-db", + token="failed-to-connect-to-db", + user_id=litellm_proxy_admin_name, + ) + else: + # raise the exception to the caller + requester_ip = _get_request_ip_address( + request=request, + use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False), + ) + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format( + str(e), + requester_ip, + ), + extra={"requester_ip": requester_ip}, + ) + + # Log this exception to OTEL, Datadog etc + user_api_key_dict = UserAPIKeyAuth( + parent_otel_span=parent_otel_span, + api_key=api_key, + ) + asyncio.create_task( + proxy_logging_obj.post_call_failure_hook( + request_data=request_data, + original_exception=e, + user_api_key_dict=user_api_key_dict, + error_type=ProxyErrorTypes.auth_error, + route=route, + ) + ) + + if isinstance(e, litellm.BudgetExceededError): + raise ProxyException( + message=e.message, + type=ProxyErrorTypes.budget_exceeded, + param=None, + code=400, + ) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=getattr(e, 
"status_code", status.HTTP_401_UNAUTHORIZED), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=status.HTTP_401_UNAUTHORIZED, + ) + + @staticmethod + def should_allow_request_on_db_unavailable() -> bool: + """ + Returns True if the request should be allowed to proceed despite the DB connection error + """ + from litellm.proxy.proxy_server import general_settings + + if general_settings.get("allow_requests_on_db_unavailable", False) is True: + return True + return False + + @staticmethod + def is_database_connection_error(e: Exception) -> bool: + """ + Returns True if the exception is from a database outage / connection error + """ + import prisma + + if isinstance(e, DB_CONNECTION_ERROR_TYPES): + return True + if isinstance(e, prisma.errors.PrismaError): + return True + if isinstance(e, ProxyException) and e.type == ProxyErrorTypes.no_db_connection: + return True + return False diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index b78619ae65..b58353bf05 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -26,7 +26,6 @@ from litellm.proxy._types import * from litellm.proxy.auth.auth_checks import ( _cache_key_object, _get_user_role, - _handle_failed_db_connection_for_get_key_object, _is_user_proxy_admin, _virtual_key_max_budget_check, _virtual_key_soft_budget_check, @@ -38,8 +37,8 @@ from litellm.proxy.auth.auth_checks import ( get_user_object, is_valid_fallback_model, ) +from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler from litellm.proxy.auth.auth_utils import ( - _get_request_ip_address, get_end_user_id_from_request_body, get_request_route, is_pass_through_provider_route, @@ -675,8 +674,11 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 if ( prisma_client is None ): # if both 
master key + user key submitted, and user key != master key, and no db connected, raise an error - return await _handle_failed_db_connection_for_get_key_object( - e=Exception("No connected db.") + raise ProxyException( + message="No connected db.", + type=ProxyErrorTypes.no_db_connection, + code=400, + param=None, ) ## check for cache hit (In-Memory Cache) @@ -685,37 +687,25 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 api_key = hash_token(token=api_key) if valid_token is None: - try: - valid_token = await get_key_object( - hashed_token=api_key, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - parent_otel_span=parent_otel_span, - proxy_logging_obj=proxy_logging_obj, - ) - # update end-user params on valid token - # These can change per request - it's important to update them here - valid_token.end_user_id = end_user_params.get("end_user_id") - valid_token.end_user_tpm_limit = end_user_params.get( - "end_user_tpm_limit" - ) - valid_token.end_user_rpm_limit = end_user_params.get( - "end_user_rpm_limit" - ) - valid_token.allowed_model_region = end_user_params.get( - "allowed_model_region" - ) - # update key budget with temp budget increase - valid_token = _update_key_budget_with_temp_budget_increase( - valid_token - ) # updating it here, allows all downstream reporting / checks to use the updated budget - except Exception: - verbose_logger.info( - "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. 
Defaulting 'valid_token' to None'".format( - api_key - ) - ) - valid_token = None + valid_token = await get_key_object( + hashed_token=api_key, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + # update end-user params on valid token + # These can change per request - it's important to update them here + valid_token.end_user_id = end_user_params.get("end_user_id") + valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit") + valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit") + valid_token.allowed_model_region = end_user_params.get( + "allowed_model_region" + ) + # update key budget with temp budget increase + valid_token = _update_key_budget_with_temp_budget_increase( + valid_token + ) # updating it here, allows all downstream reporting / checks to use the updated budget if valid_token is None: raise Exception( @@ -1015,58 +1005,15 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 route=route, start_time=start_time, ) - else: - raise Exception() except Exception as e: - requester_ip = _get_request_ip_address( + return await UserAPIKeyAuthExceptionHandler._handle_authentication_error( + e=e, request=request, - use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False), - ) - verbose_proxy_logger.exception( - "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format( - str(e), - requester_ip, - ), - extra={"requester_ip": requester_ip}, - ) - - # Log this exception to OTEL, Datadog etc - user_api_key_dict = UserAPIKeyAuth( + request_data=request_data, + route=route, parent_otel_span=parent_otel_span, api_key=api_key, ) - asyncio.create_task( - proxy_logging_obj.post_call_failure_hook( - request_data=request_data, - original_exception=e, - user_api_key_dict=user_api_key_dict, - error_type=ProxyErrorTypes.auth_error, - route=route, - ) - ) - - if isinstance(e, 
litellm.BudgetExceededError): - raise ProxyException( - message=e.message, - type=ProxyErrorTypes.budget_exceeded, - param=None, - code=400, - ) - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type=ProxyErrorTypes.auth_error, - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type=ProxyErrorTypes.auth_error, - param=getattr(e, "param", "None"), - code=status.HTTP_401_UNAUTHORIZED, - ) @tracer.wrap() diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 26ce6cb8f8..4912a35f89 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,37 +1,9 @@ model_list: - - model_name: gpt-3.5-turbo-end-user-test + - model_name: fake-openai-endpoint litellm_params: - model: azure/chatgpt-v-2 - api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ - api_version: "2023-05-15" - api_key: os.environ/AZURE_API_KEY + model: openai/fake + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ - - -mcp_tools: - - name: "get_current_time" - description: "Get the current time" - input_schema: { - "type": "object", - "properties": { - "format": { - "type": "string", - "description": "The format of the time to return", - "enum": ["short"] - } - } - } - handler: "mcp_tools.get_current_time" - - name: "get_current_date" - description: "Get the current date" - input_schema: { - "type": "object", - "properties": { - "format": { - "type": "string", - "description": "The format of the date to return", - "enum": ["short"] - } - } - } - handler: "mcp_tools.get_current_date" +general_settings: + allow_requests_on_db_unavailable: True \ No newline at end of file diff --git a/tests/litellm/proxy/auth/test_auth_exception_handler.py 
b/tests/litellm/proxy/auth/test_auth_exception_handler.py new file mode 100644 index 0000000000..b44de86e05 --- /dev/null +++ b/tests/litellm/proxy/auth/test_auth_exception_handler.py @@ -0,0 +1,184 @@ +import asyncio +import json +import os +import sys +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException, Request, status +from prisma import errors as prisma_errors +from prisma.errors import ( + ClientNotConnectedError, + DataError, + ForeignKeyViolationError, + HTTPClientClosedError, + MissingRequiredValueError, + PrismaError, + RawQueryError, + RecordNotFoundError, + TableNotFoundError, + UniqueViolationError, +) + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ProxyErrorTypes, ProxyException +from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler + + +# Test is_database_connection_error method +@pytest.mark.parametrize( + "prisma_error", + [ + PrismaError(), + DataError(data={"user_facing_error": {"meta": {"table": "test_table"}}}), + UniqueViolationError( + data={"user_facing_error": {"meta": {"table": "test_table"}}} + ), + ForeignKeyViolationError( + data={"user_facing_error": {"meta": {"table": "test_table"}}} + ), + MissingRequiredValueError( + data={"user_facing_error": {"meta": {"table": "test_table"}}} + ), + RawQueryError(data={"user_facing_error": {"meta": {"table": "test_table"}}}), + TableNotFoundError( + data={"user_facing_error": {"meta": {"table": "test_table"}}} + ), + RecordNotFoundError( + data={"user_facing_error": {"meta": {"table": "test_table"}}} + ), + HTTPClientClosedError(), + ClientNotConnectedError(), + ], +) +def test_is_database_connection_error_prisma_errors(prisma_error): + """ + Test that all Prisma errors are considered database connection errors + """ + handler = UserAPIKeyAuthExceptionHandler() + assert 
handler.is_database_connection_error(prisma_error) == True + + +def test_is_database_connection_generic_errors(): + """ + Test non-Prisma error cases for database connection checking + """ + handler = UserAPIKeyAuthExceptionHandler() + + # Test with ProxyException (DB connection) + db_proxy_exception = ProxyException( + message="DB Connection Error", + type=ProxyErrorTypes.no_db_connection, + param="test-param", + ) + assert handler.is_database_connection_error(db_proxy_exception) == True + + # Test with non-DB error + regular_exception = Exception("Regular error") + assert handler.is_database_connection_error(regular_exception) == False + + +# Test should_allow_request_on_db_unavailable method +@patch( + "litellm.proxy.proxy_server.general_settings", + {"allow_requests_on_db_unavailable": True}, +) +def test_should_allow_request_on_db_unavailable_true(): + handler = UserAPIKeyAuthExceptionHandler() + assert handler.should_allow_request_on_db_unavailable() == True + + +@patch( + "litellm.proxy.proxy_server.general_settings", + {"allow_requests_on_db_unavailable": False}, +) +def test_should_allow_request_on_db_unavailable_false(): + handler = UserAPIKeyAuthExceptionHandler() + assert handler.should_allow_request_on_db_unavailable() == False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "prisma_error", + [ + PrismaError(), + DataError(data={"user_facing_error": {"meta": {"table": "test_table"}}}), + UniqueViolationError( + data={"user_facing_error": {"meta": {"table": "test_table"}}} + ), + ForeignKeyViolationError( + data={"user_facing_error": {"meta": {"table": "test_table"}}} + ), + MissingRequiredValueError( + data={"user_facing_error": {"meta": {"table": "test_table"}}} + ), + RawQueryError(data={"user_facing_error": {"meta": {"table": "test_table"}}}), + TableNotFoundError( + data={"user_facing_error": {"meta": {"table": "test_table"}}} + ), + RecordNotFoundError( + data={"user_facing_error": {"meta": {"table": "test_table"}}} + ), + 
HTTPClientClosedError(), + ClientNotConnectedError(), + ], +) +async def test_handle_authentication_error_db_unavailable(prisma_error): + handler = UserAPIKeyAuthExceptionHandler() + + # Mock request and other dependencies + mock_request = MagicMock() + mock_request_data = {} + mock_route = "/test" + mock_span = None + mock_api_key = "test-key" + + # Test with DB connection error when requests are allowed + with patch( + "litellm.proxy.proxy_server.general_settings", + {"allow_requests_on_db_unavailable": True}, + ): + result = await handler._handle_authentication_error( + prisma_error, + mock_request, + mock_request_data, + mock_route, + mock_span, + mock_api_key, + ) + assert result.key_name == "failed-to-connect-to-db" + assert result.token == "failed-to-connect-to-db" + + +@pytest.mark.asyncio +async def test_handle_authentication_error_budget_exceeded(): + handler = UserAPIKeyAuthExceptionHandler() + + # Mock request and other dependencies + mock_request = MagicMock() + mock_request_data = {} + mock_route = "/test" + mock_span = None + mock_api_key = "test-key" + + # Test with budget exceeded error + with pytest.raises(ProxyException) as exc_info: + from litellm.exceptions import BudgetExceededError + + budget_error = BudgetExceededError( + message="Budget exceeded", current_cost=100, max_budget=100 + ) + await handler._handle_authentication_error( + budget_error, + mock_request, + mock_request_data, + mock_route, + mock_span, + mock_api_key, + ) + + assert exc_info.value.type == ProxyErrorTypes.budget_exceeded diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py index 0eb1a38755..7695306c87 100644 --- a/tests/proxy_unit_tests/test_auth_checks.py +++ b/tests/proxy_unit_tests/test_auth_checks.py @@ -13,9 +13,6 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest, litellm import httpx -from litellm.proxy.auth.auth_checks import ( - _handle_failed_db_connection_for_get_key_object, -) 
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.auth_checks import get_end_user_object from litellm.caching.caching import DualCache @@ -78,36 +75,6 @@ async def test_get_end_user_object(customer_spend, customer_budget): ) -@pytest.mark.asyncio -async def test_handle_failed_db_connection(): - """ - Test cases: - 1. When allow_requests_on_db_unavailable=True -> return UserAPIKeyAuth - 2. When allow_requests_on_db_unavailable=False -> raise original error - """ - from litellm.proxy.proxy_server import general_settings, litellm_proxy_admin_name - - # Test case 1: allow_requests_on_db_unavailable=True - general_settings["allow_requests_on_db_unavailable"] = True - mock_error = httpx.ConnectError("Failed to connect to DB") - - result = await _handle_failed_db_connection_for_get_key_object(e=mock_error) - - assert isinstance(result, UserAPIKeyAuth) - assert result.key_name == "failed-to-connect-to-db" - assert result.token == "failed-to-connect-to-db" - assert result.user_id == litellm_proxy_admin_name - - # Test case 2: allow_requests_on_db_unavailable=False - general_settings["allow_requests_on_db_unavailable"] = False - - with pytest.raises(httpx.ConnectError) as exc_info: - await _handle_failed_db_connection_for_get_key_object(e=mock_error) - print("_handle_failed_db_connection_for_get_key_object got exception", exc_info) - - assert str(exc_info.value) == "Failed to connect to DB" - - @pytest.mark.parametrize( "model, expect_to_work", [