feat(prisma_client.py): initial commit add prisma migration support to proxy

This commit is contained in:
Krrish Dholakia 2025-03-19 14:26:59 -07:00
parent 08c362e1b1
commit 665fdfc788
3 changed files with 156 additions and 21 deletions

4
.gitignore vendored
View file

@ -79,3 +79,7 @@ litellm/proxy/_experimental/out/model_hub.html
litellm/proxy/application.log
tests/llm_translation/vertex_test_account.json
tests/llm_translation/test_vertex_key.json
litellm/proxy/migrations/0_init/migration.sql
litellm/proxy/db/migrations/0_init/migration.sql
litellm/proxy/db/migrations/*
litellm/proxy/migrations/*

View file

@ -4,11 +4,17 @@ This file contains the PrismaWrapper class, which is used to wrap the Prisma cli
import asyncio
import os
import random
import shutil
import subprocess
import time
import urllib
import urllib.parse
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.secret_managers.main import str_to_bool
@ -112,6 +118,140 @@ class PrismaWrapper:
return original_attr
class PrismaManager:
@staticmethod
def _get_prisma_dir() -> str:
"""Get the path to the migrations directory"""
abspath = os.path.abspath(__file__)
dname = os.path.dirname(os.path.dirname(abspath))
return dname
@staticmethod
def _create_baseline_migration(schema_path: str) -> bool:
"""Create a baseline migration for an existing database"""
prisma_dir = PrismaManager._get_prisma_dir()
prisma_dir_path = Path(prisma_dir)
init_dir = prisma_dir_path / "migrations" / "0_init"
# Create migrations/0_init directory
init_dir.mkdir(parents=True, exist_ok=True)
# Generate migration SQL file
migration_file = init_dir / "migration.sql"
try:
# Generate migration diff with increased timeout
subprocess.run(
[
"prisma",
"migrate",
"diff",
"--from-empty",
"--to-schema-datamodel",
str(schema_path),
"--script",
],
stdout=open(migration_file, "w"),
check=True,
timeout=30,
) # 30 second timeout
# Mark migration as applied with increased timeout
subprocess.run(
[
"prisma",
"migrate",
"resolve",
"--applied",
"0_init",
],
check=True,
timeout=30,
)
return True
except subprocess.TimeoutExpired:
verbose_proxy_logger.warning(
"Migration timed out - the database might be under heavy load."
)
return False
except subprocess.CalledProcessError as e:
verbose_proxy_logger.warning(f"Error creating baseline migration: {e}")
return False
@staticmethod
def setup_database(use_migrate: bool = False) -> bool:
"""
Set up the database using either prisma migrate or prisma db push
Returns:
bool: True if setup was successful, False otherwise
"""
for attempt in range(4):
original_dir = os.getcwd()
prisma_dir = PrismaManager._get_prisma_dir()
schema_path = prisma_dir + "/schema.prisma"
os.chdir(prisma_dir)
try:
if use_migrate:
verbose_proxy_logger.info("Running prisma migrate deploy")
# First try to run migrate deploy directly
try:
subprocess.run(
["prisma", "migrate", "deploy"],
timeout=60,
check=True,
capture_output=True,
text=True,
)
verbose_proxy_logger.info("prisma migrate deploy completed")
return True
except subprocess.CalledProcessError as e:
# Check if this is the non-empty schema error
if (
"P3005" in e.stderr
and "database schema is not empty" in e.stderr
):
# Create baseline migration
if PrismaManager._create_baseline_migration(schema_path):
# Try migrate deploy again after baseline
subprocess.run(
["prisma", "migrate", "deploy"],
timeout=60,
check=True,
)
return True
else:
# If it's a different error, raise it
raise e
else:
# Use prisma db push with increased timeout
subprocess.run(
["prisma", "db", "push", "--accept-data-loss"],
timeout=60,
check=True,
)
return True
except subprocess.TimeoutExpired:
verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
time.sleep(random.randrange(5, 15))
except subprocess.CalledProcessError as e:
attempts_left = 3 - attempt
retry_msg = (
f" Retrying... ({attempts_left} attempts left)"
if attempts_left > 0
else ""
)
verbose_proxy_logger.warning(
f"The process failed to execute. Details: {e}.{retry_msg}"
)
time.sleep(random.randrange(5, 15))
finally:
os.chdir(original_dir)
return False
def should_update_prisma_schema(
disable_updates: Optional[Union[bool, str]] = None
) -> bool:

View file

@ -451,6 +451,12 @@ class ProxyInitializationHelpers:
help="Path to the SSL certfile. Use this when you want to provide SSL certificate when starting proxy",
envvar="SSL_CERTFILE_PATH",
)
@click.option(
"--use_prisma_migrate",
is_flag=True,
default=False,
help="Use prisma migrate instead of prisma db push for database schema updates",
)
@click.option("--local", is_flag=True, default=False, help="for local debugging")
def run_server( # noqa: PLR0915
host,
@ -486,6 +492,7 @@ def run_server( # noqa: PLR0915
ssl_keyfile_path,
ssl_certfile_path,
log_config,
use_prisma_migrate,
):
args = locals()
if local:
@ -715,7 +722,10 @@ def run_server( # noqa: PLR0915
if is_prisma_runnable:
from litellm.proxy.db.check_migration import check_prisma_schema_diff
from litellm.proxy.db.prisma_client import should_update_prisma_schema
from litellm.proxy.db.prisma_client import (
PrismaManager,
should_update_prisma_schema,
)
if (
should_update_prisma_schema(
@ -725,26 +735,7 @@ def run_server( # noqa: PLR0915
):
check_prisma_schema_diff(db_url=None)
else:
for _ in range(4):
# run prisma db push, before starting server
# Save the current working directory
original_dir = os.getcwd()
# set the working directory to where this script is
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
try:
subprocess.run(
["prisma", "db", "push", "--accept-data-loss"]
)
break # Exit the loop if the subprocess succeeds
except subprocess.CalledProcessError as e:
import time
print(f"Error: {e}") # noqa
time.sleep(random.randrange(start=1, stop=5))
finally:
os.chdir(original_dir)
PrismaManager.setup_database(use_migrate=use_prisma_migrate)
else:
print( # noqa
f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found." # noqa