forked from phoenix/litellm-mirror
(feat) Allow failed DB connection requests to allow virtual keys with allow_failed_db_requests
(#6605)
* fix use helper for _handle_failed_db_connection_for_get_key_object * track ALLOW_FAILED_DB_REQUESTS on prometheus * fix allow_failed_db_requests check * fix allow_requests_on_db_unavailable * fix allow_requests_on_db_unavailable * docs allow_requests_on_db_unavailable * identify user_id as litellm_proxy_admin_name when DB is failing * test_handle_failed_db_connection * fix test_user_api_key_auth_db_unavailable * update best practices for prod doc * update best practices for prod * fix handle db failure
This commit is contained in:
parent
eb171e6d95
commit
e3519aa5ae
6 changed files with 224 additions and 3 deletions
|
@ -692,9 +692,13 @@ general_settings:
|
|||
allowed_routes: ["route1", "route2"] # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)
|
||||
key_management_system: google_kms # either google_kms or azure_kms
|
||||
master_key: string
|
||||
|
||||
# Database Settings
|
||||
database_url: string
|
||||
database_connection_pool_limit: 0 # default 100
|
||||
database_connection_timeout: 0 # default 60s
|
||||
allow_requests_on_db_unavailable: boolean # if true, allows requests to succeed even when the DB cannot be reached to verify the Virtual Key
|
||||
|
||||
custom_auth: string
|
||||
max_parallel_requests: 0 # the max parallel requests allowed per deployment
|
||||
global_max_parallel_requests: 0 # the max parallel requests allowed on the proxy all up
|
||||
|
@ -766,6 +770,7 @@ general_settings:
|
|||
| database_url | string | The URL for the database connection [Set up Virtual Keys](virtual_keys) |
|
||||
| database_connection_pool_limit | integer | The limit for database connection pool [Setting DB Connection Pool limit](#configure-db-pool-limits--connection-timeouts) |
|
||||
| database_connection_timeout | integer | The timeout for database connections in seconds [Setting DB Connection Pool limit, timeout](#configure-db-pool-limits--connection-timeouts) |
|
||||
| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even if DB is unreachable. **Only use this if running LiteLLM in your VPC** This will allow requests to work even when LiteLLM cannot connect to the DB to verify a Virtual Key |
|
||||
| custom_auth | string | Write your own custom authentication logic [Doc Custom Auth](virtual_keys#custom-auth) |
|
||||
| max_parallel_requests | integer | The max parallel requests allowed per deployment |
|
||||
| global_max_parallel_requests | integer | The max parallel requests allowed on the proxy overall |
|
||||
|
|
|
@ -20,6 +20,10 @@ general_settings:
|
|||
proxy_batch_write_at: 60 # Batch write spend updates every 60s
|
||||
database_connection_pool_limit: 10 # limit the number of database connections to = MAX Number of DB Connections/Number of instances of litellm proxy (Around 10-20 is good number)
|
||||
|
||||
# OPTIONAL Best Practices
|
||||
disable_spend_logs: True # turn off writing each transaction to the db. We recommend doing this if you don't need to see Usage on the LiteLLM UI and are tracking metrics via Prometheus
|
||||
allow_requests_on_db_unavailable: True # Only USE when running LiteLLM on your VPC. Allow requests to still be processed even if the DB is unavailable. We recommend doing this if you're running LiteLLM on VPC that cannot be accessed from the public internet.
|
||||
|
||||
litellm_settings:
|
||||
request_timeout: 600 # raise Timeout error if call takes longer than 600 seconds. Default value is 6000 seconds if not set
|
||||
set_verbose: False # Switch off Debug Logging, ensure your logs do not have any debugging on
|
||||
|
@ -86,7 +90,29 @@ Set `export LITELLM_MODE="PRODUCTION"`
|
|||
|
||||
This disables the load_dotenv() functionality, which will automatically load your environment credentials from the local `.env`.
|
||||
|
||||
## 5. Set LiteLLM Salt Key
|
||||
## 5. If running LiteLLM on VPC, gracefully handle DB unavailability
|
||||
|
||||
This will allow LiteLLM to continue to process requests even if the DB is unavailable, providing graceful handling of DB outages.
|
||||
|
||||
**WARNING: Only do this if you're running LiteLLM on VPC, that cannot be accessed from the public internet.**
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
allow_requests_on_db_unavailable: True
|
||||
```
|
||||
|
||||
## 6. Disable spend_logs if you're not using the LiteLLM UI
|
||||
|
||||
By default LiteLLM will write every request to the `LiteLLM_SpendLogs` table. This is used for viewing Usage on the LiteLLM UI.
|
||||
|
||||
If you're not viewing Usage on the LiteLLM UI (most users use Prometheus when this is disabled), you can disable spend_logs by setting `disable_spend_logs` to `True`.
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
disable_spend_logs: True
|
||||
```
|
||||
|
||||
## 7. Set LiteLLM Salt Key
|
||||
|
||||
If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB.
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import traceback
|
|||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
|
@ -717,12 +718,54 @@ async def get_key_object(
|
|||
)
|
||||
|
||||
return _response
|
||||
except Exception:
|
||||
except httpx.ConnectError as e:
|
||||
return await _handle_failed_db_connection_for_get_key_object(e=e)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
|
||||
)
|
||||
|
||||
|
||||
async def _handle_failed_db_connection_for_get_key_object(
    e: Exception,
) -> UserAPIKeyAuth:
    """
    Handle httpx.ConnectError raised while reading a Virtual Key from the LiteLLM DB.

    Use this if you don't want failed DB queries to block LLM API requests.

    Args:
        e: The original exception raised while connecting to the DB.

    Returns:
        UserAPIKeyAuth: if general_settings.allow_requests_on_db_unavailable is True.
            The returned auth object identifies the caller as the proxy admin, since
            the real key owner cannot be verified without the DB.

    Raises:
        The original exception in all other cases.
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        litellm_proxy_admin_name,
        proxy_logging_obj,
    )

    # If this flag is on, requests failing to connect to the DB will be allowed
    if general_settings.get("allow_requests_on_db_unavailable", False) is True:
        # log this as a DB failure on prometheus
        proxy_logging_obj.service_logging_obj.service_failure_hook(
            service=ServiceTypes.DB,
            call_type="get_key_object",
            error=e,
            duration=0.0,
        )

        return UserAPIKeyAuth(
            key_name="failed-to-connect-to-db",
            token="failed-to-connect-to-db",
            user_id=litellm_proxy_admin_name,
        )
    else:
        # raise the original exception; the wrapper on `get_key_object` handles
        # logging the DB failure to prometheus
        raise e
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def get_org_object(
|
||||
org_id: str,
|
||||
|
|
|
@ -5,5 +5,12 @@ model_list:
|
|||
api_key: os.environ/OPENAI_API_KEY
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["gcs_bucket"]
|
||||
callbacks: ["prometheus"]
|
||||
service_callback: ["prometheus_system"]
|
||||
|
||||
|
||||
general_settings:
|
||||
allow_requests_on_db_unavailable: true
|
||||
|
||||
|
|
|
@ -12,6 +12,11 @@ sys.path.insert(
|
|||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest, litellm
|
||||
import httpx
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
_handle_failed_db_connection_for_get_key_object,
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_checks import get_end_user_object
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy._types import LiteLLM_EndUserTable, LiteLLM_BudgetTable
|
||||
|
@ -60,3 +65,33 @@ async def test_get_end_user_object(customer_spend, customer_budget):
|
|||
customer_spend, customer_budget, str(e)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_failed_db_connection():
    """
    Verify _handle_failed_db_connection_for_get_key_object in both modes:

    1. allow_requests_on_db_unavailable=True  -> returns an admin UserAPIKeyAuth
    2. allow_requests_on_db_unavailable=False -> re-raises the original error
    """
    from litellm.proxy.proxy_server import general_settings, litellm_proxy_admin_name

    db_error = httpx.ConnectError("Failed to connect to DB")

    # Case 1: requests are allowed through when the DB is unreachable
    general_settings["allow_requests_on_db_unavailable"] = True
    result = await _handle_failed_db_connection_for_get_key_object(e=db_error)

    assert isinstance(result, UserAPIKeyAuth)
    assert result.key_name == "failed-to-connect-to-db"
    assert result.token == "failed-to-connect-to-db"
    assert result.user_id == litellm_proxy_admin_name

    # Case 2: the original ConnectError propagates when the flag is off
    general_settings["allow_requests_on_db_unavailable"] = False
    with pytest.raises(httpx.ConnectError) as exc_info:
        await _handle_failed_db_connection_for_get_key_object(e=db_error)
    print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)

    assert str(exc_info.value) == "Failed to connect to DB"
|
||||
|
|
|
@ -28,6 +28,7 @@ from datetime import datetime
|
|||
from dotenv import load_dotenv
|
||||
from fastapi import Request
|
||||
from fastapi.routing import APIRoute
|
||||
import httpx
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
|
@ -51,6 +52,7 @@ from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
|||
user_info,
|
||||
user_update,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import get_key_object
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
delete_key_fn,
|
||||
generate_key_fn,
|
||||
|
@ -3307,3 +3309,106 @@ async def test_service_accounts(prisma_client):
|
|||
print("response from user_api_key_auth", result)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "general_settings", {})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_api_key_auth_db_unavailable():
    """
    Verify user_api_key_auth still succeeds when:
    1. the DB connection fails during token validation, and
    2. allow_requests_on_db_unavailable=True
    """
    litellm.set_verbose = True

    # Stubbed dependencies: a prisma client that always fails to connect,
    # and a cache that never has a hit.
    class MockPrismaClient:
        async def get_data(self, *args, **kwargs):
            print("MockPrismaClient.get_data() called")
            raise httpx.ConnectError("Failed to connect to DB")

        async def connect(self):
            print("MockPrismaClient.connect() called")

    class MockDualCache:
        async def async_get_cache(self, *args, **kwargs):
            return None

        async def async_set_cache(self, *args, **kwargs):
            pass

        async def set_cache(self, *args, **kwargs):
            pass

    # Wire the stubs into the proxy module, with the DB-unavailable flag ON
    proxy_server = litellm.proxy.proxy_server
    proxy_server.prisma_client = MockPrismaClient()
    proxy_server.user_api_key_cache = MockDualCache()
    proxy_server.master_key = "sk-1234"
    proxy_server.general_settings = {"allow_requests_on_db_unavailable": True}

    # Build a minimal request targeting a proxied route
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    # Authenticate with an arbitrary virtual key
    result = await user_api_key_auth(
        request=request,
        api_key="Bearer sk-123456789",
    )

    # The request is let through as the proxy admin, flagged as a DB failure
    assert isinstance(result, UserAPIKeyAuth)
    assert result.key_name == "failed-to-connect-to-db"
    assert result.user_id == litellm.proxy.proxy_server.litellm_proxy_admin_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_api_key_auth_db_unavailable_not_allowed():
    """
    Verify user_api_key_auth raises (the default behavior) when:

    1. the DB connection fails during token validation, and
    2. allow_requests_on_db_unavailable is unset/False
    """

    # Stubbed dependencies: a prisma client that always fails to connect,
    # and a cache that never has a hit.
    class MockPrismaClient:
        async def get_data(self, *args, **kwargs):
            print("MockPrismaClient.get_data() called")
            raise httpx.ConnectError("Failed to connect to DB")

        async def connect(self):
            print("MockPrismaClient.connect() called")

    class MockDualCache:
        async def async_get_cache(self, *args, **kwargs):
            return None

        async def async_set_cache(self, *args, **kwargs):
            pass

        async def set_cache(self, *args, **kwargs):
            pass

    # Wire the stubs into the proxy module; note general_settings is empty,
    # so allow_requests_on_db_unavailable defaults to False
    proxy_server = litellm.proxy.proxy_server
    proxy_server.prisma_client = MockPrismaClient()
    proxy_server.user_api_key_cache = MockDualCache()
    proxy_server.general_settings = {}
    proxy_server.master_key = "sk-1234"

    # Build a minimal request targeting a proxied route
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    # Authentication must fail with a ProxyException when the DB is down
    with pytest.raises(litellm.proxy._types.ProxyException):
        await user_api_key_auth(
            request=request,
            api_key="Bearer sk-123456789",
        )
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue