(proxy fix) - call connect on prisma client when running setup (#6534)

* critical fix - call connect on prisma client when running setup

* fix test_proxy_server_prisma_setup

* fix test_proxy_server_prisma_setup
This commit is contained in:
Ishaan Jaff 2024-11-05 00:43:32 +05:30 committed by GitHub
parent 3b5776e9ec
commit 8bf0005012
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 7 deletions

View file

@ -2993,7 +2993,7 @@ class ProxyStartupEvent:
scheduler.start()
@classmethod
def _setup_prisma_client(
async def _setup_prisma_client(
cls,
database_url: Optional[str],
proxy_logging_obj: ProxyLogging,
@ -3012,6 +3012,8 @@ class ProxyStartupEvent:
except Exception as e:
raise e
await prisma_client.connect()
## Add necessary views to proxy ##
asyncio.create_task(
prisma_client.check_view_exists()
@ -3033,7 +3035,7 @@ async def startup_event():
# check if DATABASE_URL in environment - load from there
if prisma_client is None:
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
prisma_client = ProxyStartupEvent._setup_prisma_client(
prisma_client = await ProxyStartupEvent._setup_prisma_client(
database_url=_db_url,
proxy_logging_obj=proxy_logging_obj,
user_api_key_cache=user_api_key_cache,
@ -3123,9 +3125,6 @@ async def startup_event():
prompt_injection_detection_obj.update_environment(router=llm_router)
verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
if prisma_client is not None:
await prisma_client.connect()
if prisma_client is not None and master_key is not None:
ProxyStartupEvent._add_master_key_hash_to_db(
master_key=master_key,

View file

@ -1909,13 +1909,15 @@ async def test_proxy_server_prisma_setup():
litellm.proxy.proxy_server, "PrismaClient", new=MagicMock()
) as mock_prisma_client:
mock_client = mock_prisma_client.return_value # This is the mocked instance
mock_client.connect = AsyncMock() # Mock the connect method
mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method
ProxyStartupEvent._setup_prisma_client(
await ProxyStartupEvent._setup_prisma_client(
database_url=os.getenv("DATABASE_URL"),
proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache),
user_api_key_cache=user_api_key_cache,
)
await asyncio.sleep(1)
# Verify our mocked methods were called
mock_client.connect.assert_called_once()
mock_client.check_view_exists.assert_called_once()