fix(proxy_server.py): don't reconnect prisma if already connected

This commit is contained in:
Krrish Dholakia 2024-01-09 11:38:42 +05:30 committed by ishaan-jaff
parent 9673e6042e
commit 27e52794df
3 changed files with 16 additions and 3 deletions

View file

@ -1150,6 +1150,15 @@ async def startup_event():
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
import json
### LOAD MASTER KEY ###
# check if master key set in environment - load from there
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
### CONNECT TO DB ###
# check if DATABASE_URL in environment - load from there
if prisma_client is None:
prisma_setup(database_url=os.getenv("DATABASE_URL"))
### LOAD CONFIG ###
worker_config = litellm.get_secret("WORKER_CONFIG")
verbose_proxy_logger.debug(f"worker_config: {worker_config}")
@ -1184,7 +1193,7 @@ async def startup_event():
prisma_setup(database_url=os.getenv("DATABASE_URL"))
verbose_proxy_logger.debug(f"prisma client - {prisma_client}")
if prisma_client:
if prisma_client is not None:
await prisma_client.connect()
if prisma_client is not None and master_key is not None:

View file

@ -578,7 +578,8 @@ class PrismaClient:
)
async def connect(self):
try:
await self.db.connect()
if self.db.is_connected() == False:
await self.db.connect()
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)

View file

@ -19,7 +19,8 @@ from litellm.proxy.proxy_server import (
save_worker_config,
initialize,
startup_event,
llm_model_list
llm_model_list,
shutdown_event
)
def test_proxy_gunicorn_startup_direct_config():
@ -36,6 +37,7 @@ def test_proxy_gunicorn_startup_direct_config():
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
os.environ["WORKER_CONFIG"] = config_fp
asyncio.run(startup_event())
asyncio.run(shutdown_event())
except Exception as e:
if "Already connected to the query engine" in str(e):
pass
@ -51,6 +53,7 @@ def test_proxy_gunicorn_startup_config_dict():
worker_config = {"config": config_fp}
os.environ["WORKER_CONFIG"] = json.dumps(worker_config)
asyncio.run(startup_event())
asyncio.run(shutdown_event())
except Exception as e:
if "Already connected to the query engine" in str(e):
pass