allow proxy to startup on DB unavailable

This commit is contained in:
Ishaan Jaff 2025-03-26 19:50:57 -07:00
parent 497570b2a6
commit 88ef97b9d1
2 changed files with 50 additions and 42 deletions

View file

@ -20,6 +20,7 @@ from litellm.proxy._types import (
WebhookEvent, WebhookEvent,
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
from litellm.proxy.health_check import ( from litellm.proxy.health_check import (
_clean_endpoint_data, _clean_endpoint_data,
_update_litellm_params_for_health_check, _update_litellm_params_for_health_check,
@ -381,20 +382,23 @@ async def _db_health_readiness_check():
global db_health_cache global db_health_cache
# Note - Intentionally don't try/except this so it raises an exception when it fails # Note - Intentionally don't try/except this so it raises an exception when it fails
try:
# if timedelta is less than 2 minutes return DB Status
time_diff = datetime.now() - db_health_cache["last_updated"]
if db_health_cache["status"] != "unknown" and time_diff < timedelta(minutes=2):
return db_health_cache
# if timedelta is less than 2 minutes return DB Status if prisma_client is None:
time_diff = datetime.now() - db_health_cache["last_updated"] db_health_cache = {"status": "disconnected", "last_updated": datetime.now()}
if db_health_cache["status"] != "unknown" and time_diff < timedelta(minutes=2): return db_health_cache
await prisma_client.health_check()
db_health_cache = {"status": "connected", "last_updated": datetime.now()}
return db_health_cache return db_health_cache
except Exception as e:
if prisma_client is None: PrismaDBExceptionHandler.handle_db_exception(e)
db_health_cache = {"status": "disconnected", "last_updated": datetime.now()}
return db_health_cache return db_health_cache
await prisma_client.health_check()
db_health_cache = {"status": "connected", "last_updated": datetime.now()}
return db_health_cache
@router.get( @router.get(
"/settings", "/settings",

View file

@ -176,6 +176,7 @@ from litellm.proxy.common_utils.proxy_state import ProxyState
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES
from litellm.proxy.credential_endpoints.endpoints import router as credential_router from litellm.proxy.credential_endpoints.endpoints import router as credential_router
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
from litellm.proxy.guardrails.guardrail_endpoints import router as guardrails_router from litellm.proxy.guardrails.guardrail_endpoints import router as guardrails_router
@ -456,15 +457,6 @@ async def proxy_startup_event(app: FastAPI):
### LOAD MASTER KEY ### ### LOAD MASTER KEY ###
# check if master key set in environment - load from there # check if master key set in environment - load from there
master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore
# check if DATABASE_URL in environment - load from there
if prisma_client is None:
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
prisma_client = await ProxyStartupEvent._setup_prisma_client(
database_url=_db_url,
proxy_logging_obj=proxy_logging_obj,
user_api_key_cache=user_api_key_cache,
)
## CHECK PREMIUM USER ## CHECK PREMIUM USER
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format( "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
@ -527,6 +519,15 @@ async def proxy_startup_event(app: FastAPI):
redis_usage_cache=redis_usage_cache, redis_usage_cache=redis_usage_cache,
) )
# check if DATABASE_URL in environment - load from there
if prisma_client is None:
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
prisma_client = await ProxyStartupEvent._setup_prisma_client(
database_url=_db_url,
proxy_logging_obj=proxy_logging_obj,
user_api_key_cache=user_api_key_cache,
)
## JWT AUTH ## ## JWT AUTH ##
ProxyStartupEvent._initialize_jwt_auth( ProxyStartupEvent._initialize_jwt_auth(
general_settings=general_settings, general_settings=general_settings,
@ -3362,33 +3363,36 @@ class ProxyStartupEvent:
- Sets up prisma client - Sets up prisma client
- Adds necessary views to proxy - Adds necessary views to proxy
""" """
prisma_client: Optional[PrismaClient] = None try:
if database_url is not None: prisma_client: Optional[PrismaClient] = None
try: if database_url is not None:
prisma_client = PrismaClient( try:
database_url=database_url, proxy_logging_obj=proxy_logging_obj prisma_client = PrismaClient(
) database_url=database_url, proxy_logging_obj=proxy_logging_obj
except Exception as e: )
raise e except Exception as e:
raise e
await prisma_client.connect() await prisma_client.connect()
## Add necessary views to proxy ## ## Add necessary views to proxy ##
asyncio.create_task( asyncio.create_task(
prisma_client.check_view_exists() prisma_client.check_view_exists()
) # check if all necessary views exist. Don't block execution ) # check if all necessary views exist. Don't block execution
asyncio.create_task( asyncio.create_task(
prisma_client._set_spend_logs_row_count_in_proxy_state() prisma_client._set_spend_logs_row_count_in_proxy_state()
) # set the spend logs row count in proxy state. Don't block execution ) # set the spend logs row count in proxy state. Don't block execution
# run a health check to ensure the DB is ready # run a health check to ensure the DB is ready
if ( if (
get_secret_bool("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", False) get_secret_bool("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", False)
is not True is not True
): ):
await prisma_client.health_check() await prisma_client.health_check()
return prisma_client return prisma_client
except Exception as e:
PrismaDBExceptionHandler.handle_db_exception(e)
@classmethod @classmethod
def _init_dd_tracer(cls): def _init_dd_tracer(cls):