From 8bf0005012cc81f9496c59733ac1197f8cc02c36 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 5 Nov 2024 00:43:32 +0530 Subject: [PATCH] (proxy fix) - call connect on prisma client when running setup (#6534) * critical fix - call connect on prisma client when running setup * fix test_proxy_server_prisma_setup * fix test_proxy_server_prisma_setup --- litellm/proxy/proxy_server.py | 9 ++++----- tests/local_testing/test_proxy_server.py | 6 ++++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ca6befef6..363ab4efd 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2993,7 +2993,7 @@ class ProxyStartupEvent: scheduler.start() @classmethod - def _setup_prisma_client( + async def _setup_prisma_client( cls, database_url: Optional[str], proxy_logging_obj: ProxyLogging, @@ -3012,6 +3012,8 @@ class ProxyStartupEvent: except Exception as e: raise e + await prisma_client.connect() + ## Add necessary views to proxy ## asyncio.create_task( prisma_client.check_view_exists() @@ -3033,7 +3035,7 @@ async def startup_event(): # check if DATABASE_URL in environment - load from there if prisma_client is None: _db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore - prisma_client = ProxyStartupEvent._setup_prisma_client( + prisma_client = await ProxyStartupEvent._setup_prisma_client( database_url=_db_url, proxy_logging_obj=proxy_logging_obj, user_api_key_cache=user_api_key_cache, @@ -3123,9 +3125,6 @@ async def startup_event(): prompt_injection_detection_obj.update_environment(router=llm_router) verbose_proxy_logger.debug("prisma_client: %s", prisma_client) - if prisma_client is not None: - await prisma_client.connect() - if prisma_client is not None and master_key is not None: ProxyStartupEvent._add_master_key_hash_to_db( master_key=master_key, diff --git a/tests/local_testing/test_proxy_server.py b/tests/local_testing/test_proxy_server.py index 51ec085ba..808b10db3 100644 --- a/tests/local_testing/test_proxy_server.py +++ b/tests/local_testing/test_proxy_server.py @@ -1909,13 +1909,15 @@ async def test_proxy_server_prisma_setup(): litellm.proxy.proxy_server, "PrismaClient", new=MagicMock() ) as mock_prisma_client: mock_client = mock_prisma_client.return_value # This is the mocked instance + mock_client.connect = AsyncMock() # Mock the connect method mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method - ProxyStartupEvent._setup_prisma_client( + await ProxyStartupEvent._setup_prisma_client( database_url=os.getenv("DATABASE_URL"), proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache), user_api_key_cache=user_api_key_cache, ) - await asyncio.sleep(1) + # Verify our mocked methods were called + mock_client.connect.assert_called_once() mock_client.check_view_exists.assert_called_once()