mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
feat(prisma_client.py): initial commit add prisma migration support to proxy
This commit is contained in:
parent
08c362e1b1
commit
665fdfc788
3 changed files with 156 additions and 21 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -79,3 +79,7 @@ litellm/proxy/_experimental/out/model_hub.html
|
|||
litellm/proxy/application.log
|
||||
tests/llm_translation/vertex_test_account.json
|
||||
tests/llm_translation/test_vertex_key.json
|
||||
litellm/proxy/migrations/0_init/migration.sql
|
||||
litellm/proxy/db/migrations/0_init/migration.sql
|
||||
litellm/proxy/db/migrations/*
|
||||
litellm/proxy/migrations/*
|
|
@ -4,11 +4,17 @@ This file contains the PrismaWrapper class, which is used to wrap the Prisma cli
|
|||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
||||
|
||||
|
@ -112,6 +118,140 @@ class PrismaWrapper:
|
|||
return original_attr
|
||||
|
||||
|
||||
class PrismaManager:
|
||||
@staticmethod
|
||||
def _get_prisma_dir() -> str:
|
||||
"""Get the path to the migrations directory"""
|
||||
abspath = os.path.abspath(__file__)
|
||||
dname = os.path.dirname(os.path.dirname(abspath))
|
||||
return dname
|
||||
|
||||
@staticmethod
|
||||
def _create_baseline_migration(schema_path: str) -> bool:
|
||||
"""Create a baseline migration for an existing database"""
|
||||
prisma_dir = PrismaManager._get_prisma_dir()
|
||||
prisma_dir_path = Path(prisma_dir)
|
||||
init_dir = prisma_dir_path / "migrations" / "0_init"
|
||||
|
||||
# Create migrations/0_init directory
|
||||
init_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate migration SQL file
|
||||
migration_file = init_dir / "migration.sql"
|
||||
|
||||
try:
|
||||
# Generate migration diff with increased timeout
|
||||
subprocess.run(
|
||||
[
|
||||
"prisma",
|
||||
"migrate",
|
||||
"diff",
|
||||
"--from-empty",
|
||||
"--to-schema-datamodel",
|
||||
str(schema_path),
|
||||
"--script",
|
||||
],
|
||||
stdout=open(migration_file, "w"),
|
||||
check=True,
|
||||
timeout=30,
|
||||
) # 30 second timeout
|
||||
|
||||
# Mark migration as applied with increased timeout
|
||||
subprocess.run(
|
||||
[
|
||||
"prisma",
|
||||
"migrate",
|
||||
"resolve",
|
||||
"--applied",
|
||||
"0_init",
|
||||
],
|
||||
check=True,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
return True
|
||||
except subprocess.TimeoutExpired:
|
||||
verbose_proxy_logger.warning(
|
||||
"Migration timed out - the database might be under heavy load."
|
||||
)
|
||||
return False
|
||||
except subprocess.CalledProcessError as e:
|
||||
verbose_proxy_logger.warning(f"Error creating baseline migration: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def setup_database(use_migrate: bool = False) -> bool:
|
||||
"""
|
||||
Set up the database using either prisma migrate or prisma db push
|
||||
|
||||
Returns:
|
||||
bool: True if setup was successful, False otherwise
|
||||
"""
|
||||
|
||||
for attempt in range(4):
|
||||
original_dir = os.getcwd()
|
||||
prisma_dir = PrismaManager._get_prisma_dir()
|
||||
schema_path = prisma_dir + "/schema.prisma"
|
||||
os.chdir(prisma_dir)
|
||||
try:
|
||||
if use_migrate:
|
||||
verbose_proxy_logger.info("Running prisma migrate deploy")
|
||||
# First try to run migrate deploy directly
|
||||
try:
|
||||
subprocess.run(
|
||||
["prisma", "migrate", "deploy"],
|
||||
timeout=60,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
verbose_proxy_logger.info("prisma migrate deploy completed")
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
# Check if this is the non-empty schema error
|
||||
if (
|
||||
"P3005" in e.stderr
|
||||
and "database schema is not empty" in e.stderr
|
||||
):
|
||||
# Create baseline migration
|
||||
if PrismaManager._create_baseline_migration(schema_path):
|
||||
# Try migrate deploy again after baseline
|
||||
subprocess.run(
|
||||
["prisma", "migrate", "deploy"],
|
||||
timeout=60,
|
||||
check=True,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
# If it's a different error, raise it
|
||||
raise e
|
||||
else:
|
||||
# Use prisma db push with increased timeout
|
||||
subprocess.run(
|
||||
["prisma", "db", "push", "--accept-data-loss"],
|
||||
timeout=60,
|
||||
check=True,
|
||||
)
|
||||
return True
|
||||
except subprocess.TimeoutExpired:
|
||||
verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
|
||||
time.sleep(random.randrange(5, 15))
|
||||
except subprocess.CalledProcessError as e:
|
||||
attempts_left = 3 - attempt
|
||||
retry_msg = (
|
||||
f" Retrying... ({attempts_left} attempts left)"
|
||||
if attempts_left > 0
|
||||
else ""
|
||||
)
|
||||
verbose_proxy_logger.warning(
|
||||
f"The process failed to execute. Details: {e}.{retry_msg}"
|
||||
)
|
||||
time.sleep(random.randrange(5, 15))
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
return False
|
||||
|
||||
|
||||
def should_update_prisma_schema(
|
||||
disable_updates: Optional[Union[bool, str]] = None
|
||||
) -> bool:
|
||||
|
|
|
@ -451,6 +451,12 @@ class ProxyInitializationHelpers:
|
|||
help="Path to the SSL certfile. Use this when you want to provide SSL certificate when starting proxy",
|
||||
envvar="SSL_CERTFILE_PATH",
|
||||
)
|
||||
@click.option(
|
||||
"--use_prisma_migrate",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Use prisma migrate instead of prisma db push for database schema updates",
|
||||
)
|
||||
@click.option("--local", is_flag=True, default=False, help="for local debugging")
|
||||
def run_server( # noqa: PLR0915
|
||||
host,
|
||||
|
@ -486,6 +492,7 @@ def run_server( # noqa: PLR0915
|
|||
ssl_keyfile_path,
|
||||
ssl_certfile_path,
|
||||
log_config,
|
||||
use_prisma_migrate,
|
||||
):
|
||||
args = locals()
|
||||
if local:
|
||||
|
@ -715,7 +722,10 @@ def run_server( # noqa: PLR0915
|
|||
|
||||
if is_prisma_runnable:
|
||||
from litellm.proxy.db.check_migration import check_prisma_schema_diff
|
||||
from litellm.proxy.db.prisma_client import should_update_prisma_schema
|
||||
from litellm.proxy.db.prisma_client import (
|
||||
PrismaManager,
|
||||
should_update_prisma_schema,
|
||||
)
|
||||
|
||||
if (
|
||||
should_update_prisma_schema(
|
||||
|
@ -725,26 +735,7 @@ def run_server( # noqa: PLR0915
|
|||
):
|
||||
check_prisma_schema_diff(db_url=None)
|
||||
else:
|
||||
for _ in range(4):
|
||||
# run prisma db push, before starting server
|
||||
# Save the current working directory
|
||||
original_dir = os.getcwd()
|
||||
# set the working directory to where this script is
|
||||
abspath = os.path.abspath(__file__)
|
||||
dname = os.path.dirname(abspath)
|
||||
os.chdir(dname)
|
||||
try:
|
||||
subprocess.run(
|
||||
["prisma", "db", "push", "--accept-data-loss"]
|
||||
)
|
||||
break # Exit the loop if the subprocess succeeds
|
||||
except subprocess.CalledProcessError as e:
|
||||
import time
|
||||
|
||||
print(f"Error: {e}") # noqa
|
||||
time.sleep(random.randrange(start=1, stop=5))
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
PrismaManager.setup_database(use_migrate=use_prisma_migrate)
|
||||
else:
|
||||
print( # noqa
|
||||
f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found." # noqa
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue