(feat) Allow failed DB connection requests to allow virtual keys with allow_failed_db_requests (#6605)

* fix use helper for _handle_failed_db_connection_for_get_key_object

* track ALLOW_FAILED_DB_REQUESTS on prometheus

* fix allow_failed_db_requests check

* fix allow_requests_on_db_unavailable

* fix allow_requests_on_db_unavailable

* docs allow_requests_on_db_unavailable

* identify user_id as litellm_proxy_admin_name when DB is failing

* test_handle_failed_db_connection

* fix test_user_api_key_auth_db_unavailable

* update best practices for prod doc

* update best practices for prod

* fix handle db failure
This commit is contained in:
Ishaan Jaff 2024-11-06 20:04:41 -08:00 committed by GitHub
parent eb171e6d95
commit e3519aa5ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 224 additions and 3 deletions

View file

@ -692,9 +692,13 @@ general_settings:
allowed_routes: ["route1", "route2"] # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)
key_management_system: google_kms # either google_kms or azure_kms
master_key: string
# Database Settings
database_url: string
database_connection_pool_limit: 0 # default 100
database_connection_timeout: 0 # default 60s
allow_requests_on_db_unavailable: boolean # if true, will allow requests that can not connect to the DB to verify Virtual Key to still work
custom_auth: string
max_parallel_requests: 0 # the max parallel requests allowed per deployment
global_max_parallel_requests: 0 # the max parallel requests allowed on the proxy all up
@ -766,6 +770,7 @@ general_settings:
| database_url | string | The URL for the database connection [Set up Virtual Keys](virtual_keys) |
| database_connection_pool_limit | integer | The limit for database connection pool [Setting DB Connection Pool limit](#configure-db-pool-limits--connection-timeouts) |
| database_connection_timeout | integer | The timeout for database connections in seconds [Setting DB Connection Pool limit, timeout](#configure-db-pool-limits--connection-timeouts) |
| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even if DB is unreachable. **Only use this if running LiteLLM in your VPC** This will allow requests to work even when LiteLLM cannot connect to the DB to verify a Virtual Key |
| custom_auth | string | Write your own custom authentication logic [Doc Custom Auth](virtual_keys#custom-auth) |
| max_parallel_requests | integer | The max parallel requests allowed per deployment |
| global_max_parallel_requests | integer | The max parallel requests allowed on the proxy overall |

View file

@ -20,6 +20,10 @@ general_settings:
proxy_batch_write_at: 60 # Batch write spend updates every 60s
database_connection_pool_limit: 10 # limit the number of database connections to = MAX Number of DB Connections/Number of instances of litellm proxy (Around 10-20 is good number)
# OPTIONAL Best Practices
disable_spend_logs: True # turn off writing each transaction to the db. We recommend doing this is you don't need to see Usage on the LiteLLM UI and are tracking metrics via Prometheus
allow_requests_on_db_unavailable: True # Only USE when running LiteLLM on your VPC. Allow requests to still be processed even if the DB is unavailable. We recommend doing this if you're running LiteLLM on VPC that cannot be accessed from the public internet.
litellm_settings:
request_timeout: 600 # raise Timeout error if call takes longer than 600 seconds. Default value is 6000seconds if not set
set_verbose: False # Switch off Debug Logging, ensure your logs do not have any debugging on
@ -86,7 +90,29 @@ Set `export LITELLM_MODE="PRODUCTION"`
This disables the load_dotenv() functionality, which will automatically load your environment credentials from the local `.env`.
## 5. Set LiteLLM Salt Key
## 5. If running LiteLLM on VPC, gracefully handle DB unavailability
This will allow LiteLLM to continue to process requests even if the DB is unavailable. This is better handling for DB unavailability.
**WARNING: Only do this if you're running LiteLLM on VPC, that cannot be accessed from the public internet.**
```yaml
general_settings:
allow_requests_on_db_unavailable: True
```
## 6. Disable spend_logs if you're not using the LiteLLM UI
By default LiteLLM will write every request to the `LiteLLM_SpendLogs` table. This is used for viewing Usage on the LiteLLM UI.
If you're not viewing Usage on the LiteLLM UI (most users use Prometheus when this is disabled), you can disable spend_logs by setting `disable_spend_logs` to `True`.
```yaml
general_settings:
disable_spend_logs: True
```
## 7. Set LiteLLM Salt Key
If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB.

View file

@ -13,6 +13,7 @@ import traceback
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Literal, Optional
import httpx
from pydantic import BaseModel
import litellm
@ -717,12 +718,54 @@ async def get_key_object(
)
return _response
except Exception:
except httpx.ConnectError as e:
return await _handle_failed_db_connection_for_get_key_object(e=e)
except Exception as e:
raise Exception(
f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
)
async def _handle_failed_db_connection_for_get_key_object(
e: Exception,
) -> UserAPIKeyAuth:
"""
Handles httpx.ConnectError when reading a Virtual Key from LiteLLM DB
Use this if you don't want failed DB queries to block LLM API reqiests
Returns:
- UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True
Raises:
- Orignal Exception in all other cases
"""
from litellm.proxy.proxy_server import (
general_settings,
litellm_proxy_admin_name,
proxy_logging_obj,
)
# If this flag is on, requests failing to connect to the DB will be allowed
if general_settings.get("allow_requests_on_db_unavailable", False) is True:
# log this as a DB failure on prometheus
proxy_logging_obj.service_logging_obj.service_failure_hook(
service=ServiceTypes.DB,
call_type="get_key_object",
error=e,
duration=0.0,
)
return UserAPIKeyAuth(
key_name="failed-to-connect-to-db",
token="failed-to-connect-to-db",
user_id=litellm_proxy_admin_name,
)
else:
# raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus
raise e
@log_to_opentelemetry
async def get_org_object(
org_id: str,

View file

@ -5,5 +5,12 @@ model_list:
api_key: os.environ/OPENAI_API_KEY
api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings:
callbacks: ["gcs_bucket"]
callbacks: ["prometheus"]
service_callback: ["prometheus_system"]
general_settings:
allow_requests_on_db_unavailable: true

View file

@ -12,6 +12,11 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, litellm
import httpx
from litellm.proxy.auth.auth_checks import (
_handle_failed_db_connection_for_get_key_object,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import get_end_user_object
from litellm.caching.caching import DualCache
from litellm.proxy._types import LiteLLM_EndUserTable, LiteLLM_BudgetTable
@ -60,3 +65,33 @@ async def test_get_end_user_object(customer_spend, customer_budget):
customer_spend, customer_budget, str(e)
)
)
@pytest.mark.asyncio
async def test_handle_failed_db_connection():
"""
Test cases:
1. When allow_requests_on_db_unavailable=True -> return UserAPIKeyAuth
2. When allow_requests_on_db_unavailable=False -> raise original error
"""
from litellm.proxy.proxy_server import general_settings, litellm_proxy_admin_name
# Test case 1: allow_requests_on_db_unavailable=True
general_settings["allow_requests_on_db_unavailable"] = True
mock_error = httpx.ConnectError("Failed to connect to DB")
result = await _handle_failed_db_connection_for_get_key_object(e=mock_error)
assert isinstance(result, UserAPIKeyAuth)
assert result.key_name == "failed-to-connect-to-db"
assert result.token == "failed-to-connect-to-db"
assert result.user_id == litellm_proxy_admin_name
# Test case 2: allow_requests_on_db_unavailable=False
general_settings["allow_requests_on_db_unavailable"] = False
with pytest.raises(httpx.ConnectError) as exc_info:
await _handle_failed_db_connection_for_get_key_object(e=mock_error)
print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)
assert str(exc_info.value) == "Failed to connect to DB"

View file

@ -28,6 +28,7 @@ from datetime import datetime
from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute
import httpx
load_dotenv()
import io
@ -51,6 +52,7 @@ from litellm.proxy.management_endpoints.internal_user_endpoints import (
user_info,
user_update,
)
from litellm.proxy.auth.auth_checks import get_key_object
from litellm.proxy.management_endpoints.key_management_endpoints import (
delete_key_fn,
generate_key_fn,
@ -3307,3 +3309,106 @@ async def test_service_accounts(prisma_client):
print("response from user_api_key_auth", result)
setattr(litellm.proxy.proxy_server, "general_settings", {})
@pytest.mark.asyncio
async def test_user_api_key_auth_db_unavailable():
"""
Test that user_api_key_auth handles DB connection failures appropriately when:
1. DB connection fails during token validation
2. allow_requests_on_db_unavailable=True
"""
litellm.set_verbose = True
# Mock dependencies
class MockPrismaClient:
async def get_data(self, *args, **kwargs):
print("MockPrismaClient.get_data() called")
raise httpx.ConnectError("Failed to connect to DB")
async def connect(self):
print("MockPrismaClient.connect() called")
pass
class MockDualCache:
async def async_get_cache(self, *args, **kwargs):
return None
async def async_set_cache(self, *args, **kwargs):
pass
async def set_cache(self, *args, **kwargs):
pass
# Set up test environment
setattr(litellm.proxy.proxy_server, "prisma_client", MockPrismaClient())
setattr(litellm.proxy.proxy_server, "user_api_key_cache", MockDualCache())
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(
litellm.proxy.proxy_server,
"general_settings",
{"allow_requests_on_db_unavailable": True},
)
# Create test request
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# Run test with a sample API key
result = await user_api_key_auth(
request=request,
api_key="Bearer sk-123456789",
)
# Verify results
assert isinstance(result, UserAPIKeyAuth)
assert result.key_name == "failed-to-connect-to-db"
assert result.user_id == litellm.proxy.proxy_server.litellm_proxy_admin_name
@pytest.mark.asyncio
async def test_user_api_key_auth_db_unavailable_not_allowed():
"""
Test that user_api_key_auth raises an exception when:
This is default behavior
1. DB connection fails during token validation
2. allow_requests_on_db_unavailable=False (default behavior)
"""
# Mock dependencies
class MockPrismaClient:
async def get_data(self, *args, **kwargs):
print("MockPrismaClient.get_data() called")
raise httpx.ConnectError("Failed to connect to DB")
async def connect(self):
print("MockPrismaClient.connect() called")
pass
class MockDualCache:
async def async_get_cache(self, *args, **kwargs):
return None
async def async_set_cache(self, *args, **kwargs):
pass
async def set_cache(self, *args, **kwargs):
pass
# Set up test environment
setattr(litellm.proxy.proxy_server, "prisma_client", MockPrismaClient())
setattr(litellm.proxy.proxy_server, "user_api_key_cache", MockDualCache())
setattr(litellm.proxy.proxy_server, "general_settings", {})
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
# Create test request
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# Run test with a sample API key
with pytest.raises(litellm.proxy._types.ProxyException):
await user_api_key_auth(
request=request,
api_key="Bearer sk-123456789",
)