diff --git a/.gitignore b/.gitignore index d35923f7c3..dab6d4ec81 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,7 @@ litellm/proxy/_experimental/out/model_hub.html litellm/proxy/application.log tests/llm_translation/vertex_test_account.json tests/llm_translation/test_vertex_key.json +litellm/proxy/migrations/0_init/migration.sql +litellm/proxy/db/migrations/0_init/migration.sql +litellm/proxy/db/migrations/* +litellm/proxy/migrations/* \ No newline at end of file diff --git a/litellm/proxy/db/prisma_client.py b/litellm/proxy/db/prisma_client.py index f8bb6b09fd..f8af46c27f 100644 --- a/litellm/proxy/db/prisma_client.py +++ b/litellm/proxy/db/prisma_client.py @@ -4,11 +4,17 @@ This file contains the PrismaWrapper class, which is used to wrap the Prisma cli import asyncio import os +import random +import shutil +import subprocess +import time import urllib import urllib.parse from datetime import datetime, timedelta +from pathlib import Path from typing import Any, Optional, Union +from litellm._logging import verbose_proxy_logger from litellm.secret_managers.main import str_to_bool @@ -112,6 +118,140 @@ class PrismaWrapper: return original_attr +class PrismaManager: + @staticmethod + def _get_prisma_dir() -> str: + """Get the path to the migrations directory""" + abspath = os.path.abspath(__file__) + dname = os.path.dirname(os.path.dirname(abspath)) + return dname + + @staticmethod + def _create_baseline_migration(schema_path: str) -> bool: + """Create a baseline migration for an existing database""" + prisma_dir = PrismaManager._get_prisma_dir() + prisma_dir_path = Path(prisma_dir) + init_dir = prisma_dir_path / "migrations" / "0_init" + + # Create migrations/0_init directory + init_dir.mkdir(parents=True, exist_ok=True) + + # Generate migration SQL file + migration_file = init_dir / "migration.sql" + + try: + # Generate migration diff with increased timeout + subprocess.run( + [ + "prisma", + "migrate", + "diff", + "--from-empty", + "--to-schema-datamodel", + str(schema_path), + "--script", + ], + stdout=open(migration_file, "w"), + check=True, + timeout=30, + ) # 30 second timeout + + # Mark migration as applied with increased timeout + subprocess.run( + [ + "prisma", + "migrate", + "resolve", + "--applied", + "0_init", + ], + check=True, + timeout=30, + ) + + return True + except subprocess.TimeoutExpired: + verbose_proxy_logger.warning( + "Migration timed out - the database might be under heavy load." + ) + return False + except subprocess.CalledProcessError as e: + verbose_proxy_logger.warning(f"Error creating baseline migration: {e}") + return False + + @staticmethod + def setup_database(use_migrate: bool = False) -> bool: + """ + Set up the database using either prisma migrate or prisma db push + + Returns: + bool: True if setup was successful, False otherwise + """ + + for attempt in range(4): + original_dir = os.getcwd() + prisma_dir = PrismaManager._get_prisma_dir() + schema_path = prisma_dir + "/schema.prisma" + os.chdir(prisma_dir) + try: + if use_migrate: + verbose_proxy_logger.info("Running prisma migrate deploy") + # First try to run migrate deploy directly + try: + subprocess.run( + ["prisma", "migrate", "deploy"], + timeout=60, + check=True, + capture_output=True, + text=True, + ) + verbose_proxy_logger.info("prisma migrate deploy completed") + return True + except subprocess.CalledProcessError as e: + # Check if this is the non-empty schema error + if ( + "P3005" in e.stderr + and "database schema is not empty" in e.stderr + ): + # Create baseline migration + if PrismaManager._create_baseline_migration(schema_path): + # Try migrate deploy again after baseline + subprocess.run( + ["prisma", "migrate", "deploy"], + timeout=60, + check=True, + ) + return True + else: + # If it's a different error, raise it + raise e + else: + # Use prisma db push with increased timeout + subprocess.run( + ["prisma", "db", "push", "--accept-data-loss"], + timeout=60, + check=True, + ) + return True + except subprocess.TimeoutExpired: + verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out") + time.sleep(random.randrange(5, 15)) + except subprocess.CalledProcessError as e: + attempts_left = 3 - attempt + retry_msg = ( + f" Retrying... ({attempts_left} attempts left)" + if attempts_left > 0 + else "" + ) + verbose_proxy_logger.warning( + f"The process failed to execute. Details: {e}.{retry_msg}" + ) + time.sleep(random.randrange(5, 15)) + finally: + os.chdir(original_dir) + return False + + def should_update_prisma_schema( disable_updates: Optional[Union[bool, str]] = None ) -> bool: diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 3a885c87c9..8196eb597e 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -451,6 +451,12 @@ class ProxyInitializationHelpers: help="Path to the SSL certfile. Use this when you want to provide SSL certificate when starting proxy", envvar="SSL_CERTFILE_PATH", ) +@click.option( + "--use_prisma_migrate", + is_flag=True, + default=False, + help="Use prisma migrate instead of prisma db push for database schema updates", +) @click.option("--local", is_flag=True, default=False, help="for local debugging") def run_server( # noqa: PLR0915 host, @@ -486,6 +492,7 @@ def run_server( # noqa: PLR0915 ssl_keyfile_path, ssl_certfile_path, log_config, + use_prisma_migrate, ): args = locals() if local: @@ -715,7 +722,10 @@ def run_server( # noqa: PLR0915 if is_prisma_runnable: from litellm.proxy.db.check_migration import check_prisma_schema_diff - from litellm.proxy.db.prisma_client import should_update_prisma_schema + from litellm.proxy.db.prisma_client import ( + PrismaManager, + should_update_prisma_schema, + ) if ( should_update_prisma_schema( @@ -725,26 +735,7 @@ def run_server( # noqa: PLR0915 ): check_prisma_schema_diff(db_url=None) else: - for _ in range(4): - # run prisma db push, before starting server - # Save the current working directory - original_dir = os.getcwd() - # set the working directory to where this script is - abspath = os.path.abspath(__file__) - dname = os.path.dirname(abspath) - os.chdir(dname) - try: - subprocess.run( - ["prisma", "db", "push", "--accept-data-loss"] - ) - break # Exit the loop if the subprocess succeeds - except subprocess.CalledProcessError as e: - import time - - print(f"Error: {e}") # noqa - time.sleep(random.randrange(start=1, stop=5)) - finally: - os.chdir(original_dir) + PrismaManager.setup_database(use_migrate=use_prisma_migrate) else: print( # noqa f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found." # noqa