(fix) allow gracefully handling DB connection errors on proxy (#7017)

* fix _handle_failed_db_connection_for_get_key_object

* _handle_failed_db_connection_for_get_key_object

* test_auth_not_connected_to_db
This commit is contained in:
Ishaan Jaff 2024-12-03 19:48:51 -08:00 committed by GitHub
parent c32a8caa5e
commit e499d39f9d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 2 deletions

View file

@ -51,6 +51,7 @@ from litellm._service_logger import ServiceLogging
from litellm.proxy._types import * from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import ( from litellm.proxy.auth.auth_checks import (
_cache_key_object, _cache_key_object,
_handle_failed_db_connection_for_get_key_object,
allowed_routes_check, allowed_routes_check,
can_key_call_model, can_key_call_model,
common_checks, common_checks,
@ -802,7 +803,9 @@ async def user_api_key_auth( # noqa: PLR0915
if ( if (
prisma_client is None prisma_client is None
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.") return await _handle_failed_db_connection_for_get_key_object(
e=Exception("No connected db.")
)
## check for cache hit (In-Memory Cache) ## check for cache hit (In-Memory Cache)
_user_role = None _user_role = None

View file

@ -14,7 +14,7 @@ import pytest
from starlette.datastructures import URL from starlette.datastructures import URL
import litellm import litellm
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth, UserAPIKeyAuth
class Request: class Request:
@ -387,3 +387,34 @@ def test_is_api_route_allowed(route, user_role, expected_result):
pass pass
else: else:
raise e raise e
@pytest.mark.asyncio
async def test_auth_not_connected_to_db():
"""
ensure requests don't fail when `prisma_client` = None
"""
from fastapi import Request
from starlette.datastructures import URL
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
user_key = "sk-12345678"
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", None)
setattr(
litellm.proxy.proxy_server,
"general_settings",
{"allow_requests_on_db_unavailable": True},
)
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
valid_token = await user_api_key_auth(request=request, api_key="Bearer " + user_key)
print("got valid token", valid_token)
assert valid_token.key_name == "failed-to-connect-to-db"
assert valid_token.token == "failed-to-connect-to-db"