diff --git a/.circleci/config.yml b/.circleci/config.yml index 063aff4c6..4bb5ebc45 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -986,6 +986,41 @@ jobs: - store_test_results: path: test-results + test_bad_database_url: + machine: + image: ubuntu-2204:2023.10.1 + resource_class: xlarge + working_directory: ~/project + steps: + - checkout + - run: + name: Build Docker image + command: | + docker build -t myapp . -f ./docker/Dockerfile.non_root + - run: + name: Run Docker container with bad DATABASE_URL + command: | + docker run --name my-app \ + -p 4000:4000 \ + -e DATABASE_URL="postgresql://wrong:wrong@wrong:5432/wrong" \ + myapp:latest \ + --port 4000 > docker_output.log 2>&1 || true + - run: + name: Display Docker logs + command: cat docker_output.log + - run: + name: Check for expected error + command: | + if grep -q "Error: P1001: Can't reach database server at" docker_output.log && \ + grep -q "httpx.ConnectError: All connection attempts failed" docker_output.log && \ + grep -q "ERROR: Application startup failed. Exiting." docker_output.log; then + echo "Expected error found. Test passed." + else + echo "Expected error not found. Test failed." + cat docker_output.log + exit 1 + fi + workflows: version: 2 build_and_test: @@ -1082,11 +1117,18 @@ workflows: only: - main - /litellm_.*/ + - test_bad_database_url: + filters: + branches: + only: + - main + - /litellm_.*/ - publish_to_pypi: requires: - local_testing - build_and_test - load_testing + - test_bad_database_url - llm_translation_testing - logging_testing - litellm_router_testing diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8edf2cee3..ce58c4d75 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3052,6 +3052,8 @@ class ProxyStartupEvent: prisma_client.check_view_exists() ) # check if all necessary views exist. Don't block execution + # run a health check to ensure the DB is ready + await prisma_client.health_check() return prisma_client diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 44243cab0..44e9d151d 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1083,19 +1083,16 @@ class PrismaClient: proxy_logging_obj: ProxyLogging, http_client: Optional[Any] = None, ): - verbose_proxy_logger.debug( - "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" - ) ## init logging object self.proxy_logging_obj = proxy_logging_obj self.iam_token_db_auth: Optional[bool] = str_to_bool( os.getenv("IAM_TOKEN_DB_AUTH") ) + verbose_proxy_logger.debug("Creating Prisma Client..") try: from prisma import Prisma # type: ignore except Exception: raise Exception("Unable to find Prisma binaries.") - verbose_proxy_logger.debug("Connecting Prisma Client to DB..") if http_client is not None: self.db = PrismaWrapper( original_prisma=Prisma(http=http_client), @@ -1114,7 +1111,7 @@ class PrismaClient: else False ), ) # Client to connect to Prisma db - verbose_proxy_logger.debug("Success - Connected Prisma Client to DB") + verbose_proxy_logger.debug("Success - Created Prisma Client") def hash_token(self, token: str): # Hash the string using SHA-256 @@ -2348,11 +2345,7 @@ class PrismaClient: """ start_time = time.time() try: - sql_query = """ - SELECT 1 - FROM "LiteLLM_VerificationToken" - LIMIT 1 - """ + sql_query = "SELECT 1" # Execute the raw query # The asterisk before `user_id_list` unpacks the list into separate arguments diff --git a/tests/local_testing/test_proxy_server.py b/tests/local_testing/test_proxy_server.py index 808b10db3..76cdf1a54 100644 --- a/tests/local_testing/test_proxy_server.py +++ b/tests/local_testing/test_proxy_server.py @@ -1911,6 +1911,7 @@ async def test_proxy_server_prisma_setup(): mock_client = mock_prisma_client.return_value # This is the mocked instance mock_client.connect = AsyncMock() # Mock the connect method mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method + mock_client.health_check = AsyncMock() # Mock the health_check method await ProxyStartupEvent._setup_prisma_client( database_url=os.getenv("DATABASE_URL"), @@ -1921,3 +1922,41 @@ async def test_proxy_server_prisma_setup(): # Verify our mocked methods were called mock_client.connect.assert_called_once() mock_client.check_view_exists.assert_called_once() + + # Note: This is REALLY IMPORTANT to check that the health check is called + # This is how we ensure the DB is ready before proceeding + mock_client.health_check.assert_called_once() + + +@pytest.mark.asyncio +async def test_proxy_server_prisma_setup_invalid_db(): + """ + PROD TEST: Test that proxy server startup fails when it's unable to connect to the database + + Think 2-3 times before editing / deleting this test, it's important for PROD + """ + from litellm.proxy.proxy_server import ProxyStartupEvent + from litellm.proxy.utils import ProxyLogging + from litellm.caching import DualCache + + user_api_key_cache = DualCache() + invalid_db_url = "postgresql://invalid:invalid@localhost:5432/nonexistent" + + _old_db_url = os.getenv("DATABASE_URL") + os.environ["DATABASE_URL"] = invalid_db_url + + with pytest.raises(Exception) as exc_info: + await ProxyStartupEvent._setup_prisma_client( + database_url=invalid_db_url, + proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache), + user_api_key_cache=user_api_key_cache, + ) + print("GOT EXCEPTION=", exc_info) + + assert "httpx.ConnectError" in str(exc_info.value) + + # # Verify the error message indicates a database connection issue + # assert any(x in str(exc_info.value).lower() for x in ["database", "connection", "authentication"]) + + if _old_db_url: + os.environ["DATABASE_URL"] = _old_db_url