forked from phoenix/litellm-mirror
(fix) ProxyStartup - Check that prisma connection is healthy when starting an instance of LiteLLM (#6627)
* fix debug statements * fix assert prisma_client.health_check is called on _setup * asser that _setup_prisma_client is called on startup proxy * fix prisma client health_check * add test_bad_database_url * add strict checks on db startup * temp remove fix to validate if check works as expected * add health_check back * test_proxy_server_prisma_setup_invalid_db
This commit is contained in:
parent
8a2b6fd8d2
commit
373f9d409e
4 changed files with 86 additions and 10 deletions
|
@ -986,6 +986,41 @@ jobs:
|
||||||
- store_test_results:
|
- store_test_results:
|
||||||
path: test-results
|
path: test-results
|
||||||
|
|
||||||
|
test_bad_database_url:
|
||||||
|
machine:
|
||||||
|
image: ubuntu-2204:2023.10.1
|
||||||
|
resource_class: xlarge
|
||||||
|
working_directory: ~/project
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
name: Build Docker image
|
||||||
|
command: |
|
||||||
|
docker build -t myapp . -f ./docker/Dockerfile.non_root
|
||||||
|
- run:
|
||||||
|
name: Run Docker container with bad DATABASE_URL
|
||||||
|
command: |
|
||||||
|
docker run --name my-app \
|
||||||
|
-p 4000:4000 \
|
||||||
|
-e DATABASE_URL="postgresql://wrong:wrong@wrong:5432/wrong" \
|
||||||
|
myapp:latest \
|
||||||
|
--port 4000 > docker_output.log 2>&1 || true
|
||||||
|
- run:
|
||||||
|
name: Display Docker logs
|
||||||
|
command: cat docker_output.log
|
||||||
|
- run:
|
||||||
|
name: Check for expected error
|
||||||
|
command: |
|
||||||
|
if grep -q "Error: P1001: Can't reach database server at" docker_output.log && \
|
||||||
|
grep -q "httpx.ConnectError: All connection attempts failed" docker_output.log && \
|
||||||
|
grep -q "ERROR: Application startup failed. Exiting." docker_output.log; then
|
||||||
|
echo "Expected error found. Test passed."
|
||||||
|
else
|
||||||
|
echo "Expected error not found. Test failed."
|
||||||
|
cat docker_output.log
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
workflows:
|
workflows:
|
||||||
version: 2
|
version: 2
|
||||||
build_and_test:
|
build_and_test:
|
||||||
|
@ -1082,11 +1117,18 @@ workflows:
|
||||||
only:
|
only:
|
||||||
- main
|
- main
|
||||||
- /litellm_.*/
|
- /litellm_.*/
|
||||||
|
- test_bad_database_url:
|
||||||
|
filters:
|
||||||
|
branches:
|
||||||
|
only:
|
||||||
|
- main
|
||||||
|
- /litellm_.*/
|
||||||
- publish_to_pypi:
|
- publish_to_pypi:
|
||||||
requires:
|
requires:
|
||||||
- local_testing
|
- local_testing
|
||||||
- build_and_test
|
- build_and_test
|
||||||
- load_testing
|
- load_testing
|
||||||
|
- test_bad_database_url
|
||||||
- llm_translation_testing
|
- llm_translation_testing
|
||||||
- logging_testing
|
- logging_testing
|
||||||
- litellm_router_testing
|
- litellm_router_testing
|
||||||
|
|
|
@ -3052,6 +3052,8 @@ class ProxyStartupEvent:
|
||||||
prisma_client.check_view_exists()
|
prisma_client.check_view_exists()
|
||||||
) # check if all necessary views exist. Don't block execution
|
) # check if all necessary views exist. Don't block execution
|
||||||
|
|
||||||
|
# run a health check to ensure the DB is ready
|
||||||
|
await prisma_client.health_check()
|
||||||
return prisma_client
|
return prisma_client
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1083,19 +1083,16 @@ class PrismaClient:
|
||||||
proxy_logging_obj: ProxyLogging,
|
proxy_logging_obj: ProxyLogging,
|
||||||
http_client: Optional[Any] = None,
|
http_client: Optional[Any] = None,
|
||||||
):
|
):
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
|
||||||
)
|
|
||||||
## init logging object
|
## init logging object
|
||||||
self.proxy_logging_obj = proxy_logging_obj
|
self.proxy_logging_obj = proxy_logging_obj
|
||||||
self.iam_token_db_auth: Optional[bool] = str_to_bool(
|
self.iam_token_db_auth: Optional[bool] = str_to_bool(
|
||||||
os.getenv("IAM_TOKEN_DB_AUTH")
|
os.getenv("IAM_TOKEN_DB_AUTH")
|
||||||
)
|
)
|
||||||
|
verbose_proxy_logger.debug("Creating Prisma Client..")
|
||||||
try:
|
try:
|
||||||
from prisma import Prisma # type: ignore
|
from prisma import Prisma # type: ignore
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Exception("Unable to find Prisma binaries.")
|
raise Exception("Unable to find Prisma binaries.")
|
||||||
verbose_proxy_logger.debug("Connecting Prisma Client to DB..")
|
|
||||||
if http_client is not None:
|
if http_client is not None:
|
||||||
self.db = PrismaWrapper(
|
self.db = PrismaWrapper(
|
||||||
original_prisma=Prisma(http=http_client),
|
original_prisma=Prisma(http=http_client),
|
||||||
|
@ -1114,7 +1111,7 @@ class PrismaClient:
|
||||||
else False
|
else False
|
||||||
),
|
),
|
||||||
) # Client to connect to Prisma db
|
) # Client to connect to Prisma db
|
||||||
verbose_proxy_logger.debug("Success - Connected Prisma Client to DB")
|
verbose_proxy_logger.debug("Success - Created Prisma Client")
|
||||||
|
|
||||||
def hash_token(self, token: str):
|
def hash_token(self, token: str):
|
||||||
# Hash the string using SHA-256
|
# Hash the string using SHA-256
|
||||||
|
@ -2348,11 +2345,7 @@ class PrismaClient:
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
sql_query = """
|
sql_query = "SELECT 1"
|
||||||
SELECT 1
|
|
||||||
FROM "LiteLLM_VerificationToken"
|
|
||||||
LIMIT 1
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Execute the raw query
|
# Execute the raw query
|
||||||
# The asterisk before `user_id_list` unpacks the list into separate arguments
|
# The asterisk before `user_id_list` unpacks the list into separate arguments
|
||||||
|
|
|
@ -1911,6 +1911,7 @@ async def test_proxy_server_prisma_setup():
|
||||||
mock_client = mock_prisma_client.return_value # This is the mocked instance
|
mock_client = mock_prisma_client.return_value # This is the mocked instance
|
||||||
mock_client.connect = AsyncMock() # Mock the connect method
|
mock_client.connect = AsyncMock() # Mock the connect method
|
||||||
mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method
|
mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method
|
||||||
|
mock_client.health_check = AsyncMock() # Mock the health_check method
|
||||||
|
|
||||||
await ProxyStartupEvent._setup_prisma_client(
|
await ProxyStartupEvent._setup_prisma_client(
|
||||||
database_url=os.getenv("DATABASE_URL"),
|
database_url=os.getenv("DATABASE_URL"),
|
||||||
|
@ -1921,3 +1922,41 @@ async def test_proxy_server_prisma_setup():
|
||||||
# Verify our mocked methods were called
|
# Verify our mocked methods were called
|
||||||
mock_client.connect.assert_called_once()
|
mock_client.connect.assert_called_once()
|
||||||
mock_client.check_view_exists.assert_called_once()
|
mock_client.check_view_exists.assert_called_once()
|
||||||
|
|
||||||
|
# Note: This is REALLY IMPORTANT to check that the health check is called
|
||||||
|
# This is how we ensure the DB is ready before proceeding
|
||||||
|
mock_client.health_check.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_proxy_server_prisma_setup_invalid_db():
|
||||||
|
"""
|
||||||
|
PROD TEST: Test that proxy server startup fails when it's unable to connect to the database
|
||||||
|
|
||||||
|
Think 2-3 times before editing / deleting this test, it's important for PROD
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import ProxyStartupEvent
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
user_api_key_cache = DualCache()
|
||||||
|
invalid_db_url = "postgresql://invalid:invalid@localhost:5432/nonexistent"
|
||||||
|
|
||||||
|
_old_db_url = os.getenv("DATABASE_URL")
|
||||||
|
os.environ["DATABASE_URL"] = invalid_db_url
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await ProxyStartupEvent._setup_prisma_client(
|
||||||
|
database_url=invalid_db_url,
|
||||||
|
proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache),
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
)
|
||||||
|
print("GOT EXCEPTION=", exc_info)
|
||||||
|
|
||||||
|
assert "httpx.ConnectError" in str(exc_info.value)
|
||||||
|
|
||||||
|
# # Verify the error message indicates a database connection issue
|
||||||
|
# assert any(x in str(exc_info.value).lower() for x in ["database", "connection", "authentication"])
|
||||||
|
|
||||||
|
if _old_db_url:
|
||||||
|
os.environ["DATABASE_URL"] = _old_db_url
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue