diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md
index 28b0b67e3..1adc4943d 100644
--- a/docs/my-website/docs/proxy/configs.md
+++ b/docs/my-website/docs/proxy/configs.md
@@ -692,9 +692,13 @@ general_settings:
   allowed_routes: ["route1", "route2"] # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)
   key_management_system: google_kms # either google_kms or azure_kms
   master_key: string
+
+  # Database Settings
   database_url: string
   database_connection_pool_limit: 0 # default 100
   database_connection_timeout: 0 # default 60s
+  allow_requests_on_db_unavailable: boolean # if true, allows requests to succeed even when the DB cannot be reached to verify the Virtual Key
+
   custom_auth: string
   max_parallel_requests: 0 # the max parallel requests allowed per deployment
   global_max_parallel_requests: 0 # the max parallel requests allowed on the proxy all up
@@ -766,6 +770,7 @@ general_settings:
 | database_url | string | The URL for the database connection [Set up Virtual Keys](virtual_keys) |
 | database_connection_pool_limit | integer | The limit for database connection pool [Setting DB Connection Pool limit](#configure-db-pool-limits--connection-timeouts) |
 | database_connection_timeout | integer | The timeout for database connections in seconds [Setting DB Connection Pool limit, timeout](#configure-db-pool-limits--connection-timeouts) |
+| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even when LiteLLM cannot connect to the DB to verify a Virtual Key. **Only use this if LiteLLM runs inside your VPC.** |
 | custom_auth | string | Write your own custom authentication logic [Doc Custom Auth](virtual_keys#custom-auth) |
 | max_parallel_requests | integer | The max parallel requests allowed per deployment |
 | global_max_parallel_requests | integer | The max parallel requests allowed on the proxy overall |
diff --git a/docs/my-website/docs/proxy/prod.md b/docs/my-website/docs/proxy/prod.md
index 99fa19e77..66c719e5d 100644
--- a/docs/my-website/docs/proxy/prod.md
+++ b/docs/my-website/docs/proxy/prod.md
@@ -20,6 +20,10 @@ general_settings:
   proxy_batch_write_at: 60 # Batch write spend updates every 60s
   database_connection_pool_limit: 10 # limit the number of database connections to = MAX Number of DB Connections/Number of instances of litellm proxy (Around 10-20 is good number)
 
+  # OPTIONAL Best Practices
+  disable_spend_logs: True # turn off writing each transaction to the db. We recommend this if you don't need to see Usage on the LiteLLM UI and are tracking metrics via Prometheus
+  allow_requests_on_db_unavailable: True # Only use when running LiteLLM inside a VPC that cannot be accessed from the public internet. Allows requests to still be processed even if the DB is unavailable.
+
 litellm_settings:
   request_timeout: 600 # raise Timeout error if call takes longer than 600 seconds. Default value is 6000seconds if not set
   set_verbose: False # Switch off Debug Logging, ensure your logs do not have any debugging on
@@ -86,7 +90,29 @@ Set `export LITELLM_MODE="PRODUCTION"`
 
 This disables the load_dotenv() functionality, which will automatically load your environment credentials from the local `.env`.
 
-## 5. Set LiteLLM Salt Key
+## 5. If running LiteLLM in a VPC, gracefully handle DB unavailability
+
+This lets LiteLLM continue to process requests even if the DB is unavailable, instead of failing every request when it cannot reach the DB to verify Virtual Keys.
+
+**WARNING: Only do this if you're running LiteLLM in a VPC that cannot be accessed from the public internet.**
+
+```yaml
+general_settings:
+  allow_requests_on_db_unavailable: True
+```
+
+## 6. Disable spend_logs if you're not using the LiteLLM UI
+
+By default, LiteLLM writes every request to the `LiteLLM_SpendLogs` table. This data is used for viewing Usage on the LiteLLM UI.
+
+If you're not viewing Usage on the LiteLLM UI (most users who disable this track metrics via Prometheus instead), you can turn off spend logs by setting `disable_spend_logs` to `True`.
+
+```yaml
+general_settings:
+  disable_spend_logs: True
+```
+
+## 7. Set LiteLLM Salt Key
 
 If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB.
 
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index e00d494d9..dcc1c5e90 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -13,6 +13,7 @@ import traceback
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, List, Literal, Optional
 
+import httpx
 from pydantic import BaseModel
 
 import litellm
@@ -717,12 +718,54 @@ async def get_key_object(
         )
 
         return _response
-    except Exception:
+    except httpx.ConnectError as e:
+        return await _handle_failed_db_connection_for_get_key_object(e=e)
+    except Exception as e:
         raise Exception(
             f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
         )
 
 
+async def _handle_failed_db_connection_for_get_key_object(
+    e: Exception,
+) -> UserAPIKeyAuth:
+    """
+    Handles httpx.ConnectError when reading a Virtual Key from the LiteLLM DB
+
+    Use this if you don't want failed DB queries to block LLM API requests
+
+    Returns:
+    - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True
+
+    Raises:
+    - Original exception in all other cases
+    """
+    from litellm.proxy.proxy_server import (
+        general_settings,
+        litellm_proxy_admin_name,
+        proxy_logging_obj,
+    )
+
+    # If this flag is on, requests failing to connect to the DB will be allowed
+    if general_settings.get("allow_requests_on_db_unavailable", False) is True:
+        # log this as a DB failure on prometheus
+        proxy_logging_obj.service_logging_obj.service_failure_hook(
+            service=ServiceTypes.DB,
+            call_type="get_key_object",
+            error=e,
+            duration=0.0,
+        )
+
+        return UserAPIKeyAuth(
+            key_name="failed-to-connect-to-db",
+            token="failed-to-connect-to-db",
+            user_id=litellm_proxy_admin_name,
+        )
+    else:
+        # raise the original exception; the wrapper on `get_key_object` handles logging the DB failure to prometheus
+        raise e
+
+
 @log_to_opentelemetry
 async def get_org_object(
     org_id: str,
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 9767677cf..694c1613d 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -5,5 +5,12 @@ model_list:
       api_key: os.environ/OPENAI_API_KEY
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
 
+
 litellm_settings:
-  callbacks: ["gcs_bucket"]
\ No newline at end of file
+  callbacks: ["prometheus"]
+  service_callback: ["prometheus_system"]
+
+
+general_settings:
+  allow_requests_on_db_unavailable: true
+
diff --git a/tests/local_testing/test_auth_checks.py b/tests/local_testing/test_auth_checks.py
index 3ea113c28..f1683a153 100644
--- a/tests/local_testing/test_auth_checks.py
+++ b/tests/local_testing/test_auth_checks.py
@@ -12,6 +12,11 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
import pytest, litellm +import httpx +from litellm.proxy.auth.auth_checks import ( + _handle_failed_db_connection_for_get_key_object, +) +from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.auth_checks import get_end_user_object from litellm.caching.caching import DualCache from litellm.proxy._types import LiteLLM_EndUserTable, LiteLLM_BudgetTable @@ -60,3 +65,33 @@ async def test_get_end_user_object(customer_spend, customer_budget): customer_spend, customer_budget, str(e) ) ) + + +@pytest.mark.asyncio +async def test_handle_failed_db_connection(): + """ + Test cases: + 1. When allow_requests_on_db_unavailable=True -> return UserAPIKeyAuth + 2. When allow_requests_on_db_unavailable=False -> raise original error + """ + from litellm.proxy.proxy_server import general_settings, litellm_proxy_admin_name + + # Test case 1: allow_requests_on_db_unavailable=True + general_settings["allow_requests_on_db_unavailable"] = True + mock_error = httpx.ConnectError("Failed to connect to DB") + + result = await _handle_failed_db_connection_for_get_key_object(e=mock_error) + + assert isinstance(result, UserAPIKeyAuth) + assert result.key_name == "failed-to-connect-to-db" + assert result.token == "failed-to-connect-to-db" + assert result.user_id == litellm_proxy_admin_name + + # Test case 2: allow_requests_on_db_unavailable=False + general_settings["allow_requests_on_db_unavailable"] = False + + with pytest.raises(httpx.ConnectError) as exc_info: + await _handle_failed_db_connection_for_get_key_object(e=mock_error) + print("_handle_failed_db_connection_for_get_key_object got exception", exc_info) + + assert str(exc_info.value) == "Failed to connect to DB" diff --git a/tests/local_testing/test_key_generate_prisma.py b/tests/local_testing/test_key_generate_prisma.py index e009e214c..66b9c7b8f 100644 --- a/tests/local_testing/test_key_generate_prisma.py +++ b/tests/local_testing/test_key_generate_prisma.py @@ -28,6 +28,7 @@ from datetime import datetime from dotenv import load_dotenv from fastapi import Request from fastapi.routing import APIRoute +import httpx load_dotenv() import io @@ -51,6 +52,7 @@ from litellm.proxy.management_endpoints.internal_user_endpoints import ( user_info, user_update, ) +from litellm.proxy.auth.auth_checks import get_key_object from litellm.proxy.management_endpoints.key_management_endpoints import ( delete_key_fn, generate_key_fn, @@ -3307,3 +3309,106 @@ async def test_service_accounts(prisma_client): print("response from user_api_key_auth", result) setattr(litellm.proxy.proxy_server, "general_settings", {}) + + +@pytest.mark.asyncio +async def test_user_api_key_auth_db_unavailable(): + """ + Test that user_api_key_auth handles DB connection failures appropriately when: + 1. DB connection fails during token validation + 2. 
allow_requests_on_db_unavailable=True
+    """
+    litellm.set_verbose = True
+
+    # Mock dependencies
+    class MockPrismaClient:
+        async def get_data(self, *args, **kwargs):
+            print("MockPrismaClient.get_data() called")
+            raise httpx.ConnectError("Failed to connect to DB")
+
+        async def connect(self):
+            print("MockPrismaClient.connect() called")
+            pass
+
+    class MockDualCache:
+        async def async_get_cache(self, *args, **kwargs):
+            return None
+
+        async def async_set_cache(self, *args, **kwargs):
+            pass
+
+        async def set_cache(self, *args, **kwargs):
+            pass
+
+    # Set up test environment
+    setattr(litellm.proxy.proxy_server, "prisma_client", MockPrismaClient())
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", MockDualCache())
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(
+        litellm.proxy.proxy_server,
+        "general_settings",
+        {"allow_requests_on_db_unavailable": True},
+    )
+
+    # Create test request
+    request = Request(scope={"type": "http"})
+    request._url = URL(url="/chat/completions")
+
+    # Run test with a sample API key
+    result = await user_api_key_auth(
+        request=request,
+        api_key="Bearer sk-123456789",
+    )
+
+    # Verify results
+    assert isinstance(result, UserAPIKeyAuth)
+    assert result.key_name == "failed-to-connect-to-db"
+    assert result.user_id == litellm.proxy.proxy_server.litellm_proxy_admin_name
+
+
+@pytest.mark.asyncio
+async def test_user_api_key_auth_db_unavailable_not_allowed():
+    """
+    Test that user_api_key_auth raises an exception when DB access fails
+    and allow_requests_on_db_unavailable is left at its default:
+
+    1. DB connection fails during token validation
+    2. allow_requests_on_db_unavailable=False (default behavior)
+    """
+
+    # Mock dependencies
+    class MockPrismaClient:
+        async def get_data(self, *args, **kwargs):
+            print("MockPrismaClient.get_data() called")
+            raise httpx.ConnectError("Failed to connect to DB")
+
+        async def connect(self):
+            print("MockPrismaClient.connect() called")
+            pass
+
+    class MockDualCache:
+        async def async_get_cache(self, *args, **kwargs):
+            return None
+
+        async def async_set_cache(self, *args, **kwargs):
+            pass
+
+        async def set_cache(self, *args, **kwargs):
+            pass
+
+    # Set up test environment
+    setattr(litellm.proxy.proxy_server, "prisma_client", MockPrismaClient())
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", MockDualCache())
+    setattr(litellm.proxy.proxy_server, "general_settings", {})
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+
+    # Create test request
+    request = Request(scope={"type": "http"})
+    request._url = URL(url="/chat/completions")
+
+    # Run test with a sample API key
+    with pytest.raises(litellm.proxy._types.ProxyException):
+        await user_api_key_auth(
+            request=request,
+            api_key="Bearer sk-123456789",
+        )
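
For quick reference, a minimal proxy config sketch that combines the settings touched by this change, assembled from the `proxy_config.yaml` and `prod.md` hunks above. The model entry, upstream model name, and master key are illustrative placeholders, not part of the patch:

```yaml
model_list:
  - model_name: fake-openai-endpoint                  # illustrative model entry
    litellm_params:
      model: openai/my-fake-model                     # placeholder upstream model
      api_key: os.environ/OPENAI_API_KEY
      api_base: https://exampleopenaiendpoint-production.up.railway.app/

litellm_settings:
  callbacks: ["prometheus"]                           # request-level metrics
  service_callback: ["prometheus_system"]             # service health metrics; DB connection failures are reported here

general_settings:
  master_key: sk-1234                                 # placeholder; use a real secret in production
  allow_requests_on_db_unavailable: true              # keep serving requests when the DB cannot be reached (VPC-only deployments)
  disable_spend_logs: true                            # optional: skip per-request spend log writes if you track usage via Prometheus
```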