forked from phoenix/litellm-mirror
(fix) ProxyStartup - Check that prisma connection is healthy when starting an instance of LiteLLM (#6627)
* fix debug statements * fix assert prisma_client.health_check is called on _setup * asser that _setup_prisma_client is called on startup proxy * fix prisma client health_check * add test_bad_database_url * add strict checks on db startup * temp remove fix to validate if check works as expected * add health_check back * test_proxy_server_prisma_setup_invalid_db
This commit is contained in:
parent
8a2b6fd8d2
commit
373f9d409e
4 changed files with 86 additions and 10 deletions
|
@ -986,6 +986,41 @@ jobs:
|
|||
- store_test_results:
|
||||
path: test-results
|
||||
|
||||
test_bad_database_url:
|
||||
machine:
|
||||
image: ubuntu-2204:2023.10.1
|
||||
resource_class: xlarge
|
||||
working_directory: ~/project
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build Docker image
|
||||
command: |
|
||||
docker build -t myapp . -f ./docker/Dockerfile.non_root
|
||||
- run:
|
||||
name: Run Docker container with bad DATABASE_URL
|
||||
command: |
|
||||
docker run --name my-app \
|
||||
-p 4000:4000 \
|
||||
-e DATABASE_URL="postgresql://wrong:wrong@wrong:5432/wrong" \
|
||||
myapp:latest \
|
||||
--port 4000 > docker_output.log 2>&1 || true
|
||||
- run:
|
||||
name: Display Docker logs
|
||||
command: cat docker_output.log
|
||||
- run:
|
||||
name: Check for expected error
|
||||
command: |
|
||||
if grep -q "Error: P1001: Can't reach database server at" docker_output.log && \
|
||||
grep -q "httpx.ConnectError: All connection attempts failed" docker_output.log && \
|
||||
grep -q "ERROR: Application startup failed. Exiting." docker_output.log; then
|
||||
echo "Expected error found. Test passed."
|
||||
else
|
||||
echo "Expected error not found. Test failed."
|
||||
cat docker_output.log
|
||||
exit 1
|
||||
fi
|
||||
|
||||
workflows:
|
||||
version: 2
|
||||
build_and_test:
|
||||
|
@ -1082,11 +1117,18 @@ workflows:
|
|||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- test_bad_database_url:
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- publish_to_pypi:
|
||||
requires:
|
||||
- local_testing
|
||||
- build_and_test
|
||||
- load_testing
|
||||
- test_bad_database_url
|
||||
- llm_translation_testing
|
||||
- logging_testing
|
||||
- litellm_router_testing
|
||||
|
|
|
@ -3052,6 +3052,8 @@ class ProxyStartupEvent:
|
|||
prisma_client.check_view_exists()
|
||||
) # check if all necessary views exist. Don't block execution
|
||||
|
||||
# run a health check to ensure the DB is ready
|
||||
await prisma_client.health_check()
|
||||
return prisma_client
|
||||
|
||||
|
||||
|
|
|
@ -1083,19 +1083,16 @@ class PrismaClient:
|
|||
proxy_logging_obj: ProxyLogging,
|
||||
http_client: Optional[Any] = None,
|
||||
):
|
||||
verbose_proxy_logger.debug(
|
||||
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
||||
)
|
||||
## init logging object
|
||||
self.proxy_logging_obj = proxy_logging_obj
|
||||
self.iam_token_db_auth: Optional[bool] = str_to_bool(
|
||||
os.getenv("IAM_TOKEN_DB_AUTH")
|
||||
)
|
||||
verbose_proxy_logger.debug("Creating Prisma Client..")
|
||||
try:
|
||||
from prisma import Prisma # type: ignore
|
||||
except Exception:
|
||||
raise Exception("Unable to find Prisma binaries.")
|
||||
verbose_proxy_logger.debug("Connecting Prisma Client to DB..")
|
||||
if http_client is not None:
|
||||
self.db = PrismaWrapper(
|
||||
original_prisma=Prisma(http=http_client),
|
||||
|
@ -1114,7 +1111,7 @@ class PrismaClient:
|
|||
else False
|
||||
),
|
||||
) # Client to connect to Prisma db
|
||||
verbose_proxy_logger.debug("Success - Connected Prisma Client to DB")
|
||||
verbose_proxy_logger.debug("Success - Created Prisma Client")
|
||||
|
||||
def hash_token(self, token: str):
|
||||
# Hash the string using SHA-256
|
||||
|
@ -2348,11 +2345,7 @@ class PrismaClient:
|
|||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
sql_query = """
|
||||
SELECT 1
|
||||
FROM "LiteLLM_VerificationToken"
|
||||
LIMIT 1
|
||||
"""
|
||||
sql_query = "SELECT 1"
|
||||
|
||||
# Execute the raw query
|
||||
# The asterisk before `user_id_list` unpacks the list into separate arguments
|
||||
|
|
|
@ -1911,6 +1911,7 @@ async def test_proxy_server_prisma_setup():
|
|||
mock_client = mock_prisma_client.return_value # This is the mocked instance
|
||||
mock_client.connect = AsyncMock() # Mock the connect method
|
||||
mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method
|
||||
mock_client.health_check = AsyncMock() # Mock the health_check method
|
||||
|
||||
await ProxyStartupEvent._setup_prisma_client(
|
||||
database_url=os.getenv("DATABASE_URL"),
|
||||
|
@ -1921,3 +1922,41 @@ async def test_proxy_server_prisma_setup():
|
|||
# Verify our mocked methods were called
|
||||
mock_client.connect.assert_called_once()
|
||||
mock_client.check_view_exists.assert_called_once()
|
||||
|
||||
# Note: This is REALLY IMPORTANT to check that the health check is called
|
||||
# This is how we ensure the DB is ready before proceeding
|
||||
mock_client.health_check.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_server_prisma_setup_invalid_db():
|
||||
"""
|
||||
PROD TEST: Test that proxy server startup fails when it's unable to connect to the database
|
||||
|
||||
Think 2-3 times before editing / deleting this test, it's important for PROD
|
||||
"""
|
||||
from litellm.proxy.proxy_server import ProxyStartupEvent
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.caching import DualCache
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
invalid_db_url = "postgresql://invalid:invalid@localhost:5432/nonexistent"
|
||||
|
||||
_old_db_url = os.getenv("DATABASE_URL")
|
||||
os.environ["DATABASE_URL"] = invalid_db_url
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await ProxyStartupEvent._setup_prisma_client(
|
||||
database_url=invalid_db_url,
|
||||
proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache),
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
print("GOT EXCEPTION=", exc_info)
|
||||
|
||||
assert "httpx.ConnectError" in str(exc_info.value)
|
||||
|
||||
# # Verify the error message indicates a database connection issue
|
||||
# assert any(x in str(exc_info.value).lower() for x in ["database", "connection", "authentication"])
|
||||
|
||||
if _old_db_url:
|
||||
os.environ["DATABASE_URL"] = _old_db_url
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue