diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e0d2b195e6..c81da86452 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1156,7 +1156,7 @@ async def startup_event(): ### CONNECT TO DB ### # check if DATABASE_URL in environment - load from there - if os.getenv("DATABASE_URL", None) is not None and prisma_client is None: + if prisma_client is None: prisma_setup(database_url=os.getenv("DATABASE_URL")) ### LOAD CONFIG ### @@ -1184,7 +1184,7 @@ async def startup_event(): ) # start the background health check coroutine. verbose_proxy_logger.debug(f"prisma client - {prisma_client}") - if prisma_client: + if prisma_client is not None: await prisma_client.connect() if prisma_client is not None and master_key is not None: diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 3412b97a41..798c02b647 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -555,7 +555,8 @@ class PrismaClient: ) async def connect(self): try: - await self.db.connect() + if self.db.is_connected() == False: + await self.db.connect() except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) diff --git a/litellm/tests/test_proxy_startup.py b/litellm/tests/test_proxy_startup.py index c6f3f55396..1183e8c616 100644 --- a/litellm/tests/test_proxy_startup.py +++ b/litellm/tests/test_proxy_startup.py @@ -19,7 +19,8 @@ from litellm.proxy.proxy_server import ( save_worker_config, initialize, startup_event, - llm_model_list + llm_model_list, + shutdown_event ) def test_proxy_gunicorn_startup_direct_config(): @@ -36,6 +37,7 @@ def test_proxy_gunicorn_startup_direct_config(): config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml" os.environ["WORKER_CONFIG"] = config_fp asyncio.run(startup_event()) + asyncio.run(shutdown_event()) except Exception as e: if "Already connected to the query engine" in str(e): pass @@ -51,6 +53,7 @@ def test_proxy_gunicorn_startup_config_dict(): worker_config = {"config": config_fp} os.environ["WORKER_CONFIG"] = json.dumps(worker_config) asyncio.run(startup_event()) + asyncio.run(shutdown_event()) except Exception as e: if "Already connected to the query engine" in str(e): pass