mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(prisma_client.py): initial commit add prisma migration support to proxy
This commit is contained in:
parent
08c362e1b1
commit
665fdfc788
3 changed files with 156 additions and 21 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -79,3 +79,7 @@ litellm/proxy/_experimental/out/model_hub.html
|
||||||
litellm/proxy/application.log
|
litellm/proxy/application.log
|
||||||
tests/llm_translation/vertex_test_account.json
|
tests/llm_translation/vertex_test_account.json
|
||||||
tests/llm_translation/test_vertex_key.json
|
tests/llm_translation/test_vertex_key.json
|
||||||
|
litellm/proxy/migrations/0_init/migration.sql
|
||||||
|
litellm/proxy/db/migrations/0_init/migration.sql
|
||||||
|
litellm/proxy/db/migrations/*
|
||||||
|
litellm/proxy/migrations/*
|
|
@ -4,11 +4,17 @@ This file contains the PrismaWrapper class, which is used to wrap the Prisma cli
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
import urllib
|
import urllib
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.secret_managers.main import str_to_bool
|
from litellm.secret_managers.main import str_to_bool
|
||||||
|
|
||||||
|
|
||||||
|
@ -112,6 +118,140 @@ class PrismaWrapper:
|
||||||
return original_attr
|
return original_attr
|
||||||
|
|
||||||
|
|
||||||
|
class PrismaManager:
|
||||||
|
@staticmethod
|
||||||
|
def _get_prisma_dir() -> str:
|
||||||
|
"""Get the path to the migrations directory"""
|
||||||
|
abspath = os.path.abspath(__file__)
|
||||||
|
dname = os.path.dirname(os.path.dirname(abspath))
|
||||||
|
return dname
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_baseline_migration(schema_path: str) -> bool:
|
||||||
|
"""Create a baseline migration for an existing database"""
|
||||||
|
prisma_dir = PrismaManager._get_prisma_dir()
|
||||||
|
prisma_dir_path = Path(prisma_dir)
|
||||||
|
init_dir = prisma_dir_path / "migrations" / "0_init"
|
||||||
|
|
||||||
|
# Create migrations/0_init directory
|
||||||
|
init_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Generate migration SQL file
|
||||||
|
migration_file = init_dir / "migration.sql"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate migration diff with increased timeout
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
"prisma",
|
||||||
|
"migrate",
|
||||||
|
"diff",
|
||||||
|
"--from-empty",
|
||||||
|
"--to-schema-datamodel",
|
||||||
|
str(schema_path),
|
||||||
|
"--script",
|
||||||
|
],
|
||||||
|
stdout=open(migration_file, "w"),
|
||||||
|
check=True,
|
||||||
|
timeout=30,
|
||||||
|
) # 30 second timeout
|
||||||
|
|
||||||
|
# Mark migration as applied with increased timeout
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
"prisma",
|
||||||
|
"migrate",
|
||||||
|
"resolve",
|
||||||
|
"--applied",
|
||||||
|
"0_init",
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
"Migration timed out - the database might be under heavy load."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
verbose_proxy_logger.warning(f"Error creating baseline migration: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def setup_database(use_migrate: bool = False) -> bool:
|
||||||
|
"""
|
||||||
|
Set up the database using either prisma migrate or prisma db push
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if setup was successful, False otherwise
|
||||||
|
"""
|
||||||
|
|
||||||
|
for attempt in range(4):
|
||||||
|
original_dir = os.getcwd()
|
||||||
|
prisma_dir = PrismaManager._get_prisma_dir()
|
||||||
|
schema_path = prisma_dir + "/schema.prisma"
|
||||||
|
os.chdir(prisma_dir)
|
||||||
|
try:
|
||||||
|
if use_migrate:
|
||||||
|
verbose_proxy_logger.info("Running prisma migrate deploy")
|
||||||
|
# First try to run migrate deploy directly
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
["prisma", "migrate", "deploy"],
|
||||||
|
timeout=60,
|
||||||
|
check=True,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.info("prisma migrate deploy completed")
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
# Check if this is the non-empty schema error
|
||||||
|
if (
|
||||||
|
"P3005" in e.stderr
|
||||||
|
and "database schema is not empty" in e.stderr
|
||||||
|
):
|
||||||
|
# Create baseline migration
|
||||||
|
if PrismaManager._create_baseline_migration(schema_path):
|
||||||
|
# Try migrate deploy again after baseline
|
||||||
|
subprocess.run(
|
||||||
|
["prisma", "migrate", "deploy"],
|
||||||
|
timeout=60,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# If it's a different error, raise it
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
# Use prisma db push with increased timeout
|
||||||
|
subprocess.run(
|
||||||
|
["prisma", "db", "push", "--accept-data-loss"],
|
||||||
|
timeout=60,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
|
||||||
|
time.sleep(random.randrange(5, 15))
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
attempts_left = 3 - attempt
|
||||||
|
retry_msg = (
|
||||||
|
f" Retrying... ({attempts_left} attempts left)"
|
||||||
|
if attempts_left > 0
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
f"The process failed to execute. Details: {e}.{retry_msg}"
|
||||||
|
)
|
||||||
|
time.sleep(random.randrange(5, 15))
|
||||||
|
finally:
|
||||||
|
os.chdir(original_dir)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def should_update_prisma_schema(
|
def should_update_prisma_schema(
|
||||||
disable_updates: Optional[Union[bool, str]] = None
|
disable_updates: Optional[Union[bool, str]] = None
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|
|
@ -451,6 +451,12 @@ class ProxyInitializationHelpers:
|
||||||
help="Path to the SSL certfile. Use this when you want to provide SSL certificate when starting proxy",
|
help="Path to the SSL certfile. Use this when you want to provide SSL certificate when starting proxy",
|
||||||
envvar="SSL_CERTFILE_PATH",
|
envvar="SSL_CERTFILE_PATH",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--use_prisma_migrate",
|
||||||
|
is_flag=True,
|
||||||
|
default=False,
|
||||||
|
help="Use prisma migrate instead of prisma db push for database schema updates",
|
||||||
|
)
|
||||||
@click.option("--local", is_flag=True, default=False, help="for local debugging")
|
@click.option("--local", is_flag=True, default=False, help="for local debugging")
|
||||||
def run_server( # noqa: PLR0915
|
def run_server( # noqa: PLR0915
|
||||||
host,
|
host,
|
||||||
|
@ -486,6 +492,7 @@ def run_server( # noqa: PLR0915
|
||||||
ssl_keyfile_path,
|
ssl_keyfile_path,
|
||||||
ssl_certfile_path,
|
ssl_certfile_path,
|
||||||
log_config,
|
log_config,
|
||||||
|
use_prisma_migrate,
|
||||||
):
|
):
|
||||||
args = locals()
|
args = locals()
|
||||||
if local:
|
if local:
|
||||||
|
@ -715,7 +722,10 @@ def run_server( # noqa: PLR0915
|
||||||
|
|
||||||
if is_prisma_runnable:
|
if is_prisma_runnable:
|
||||||
from litellm.proxy.db.check_migration import check_prisma_schema_diff
|
from litellm.proxy.db.check_migration import check_prisma_schema_diff
|
||||||
from litellm.proxy.db.prisma_client import should_update_prisma_schema
|
from litellm.proxy.db.prisma_client import (
|
||||||
|
PrismaManager,
|
||||||
|
should_update_prisma_schema,
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
should_update_prisma_schema(
|
should_update_prisma_schema(
|
||||||
|
@ -725,26 +735,7 @@ def run_server( # noqa: PLR0915
|
||||||
):
|
):
|
||||||
check_prisma_schema_diff(db_url=None)
|
check_prisma_schema_diff(db_url=None)
|
||||||
else:
|
else:
|
||||||
for _ in range(4):
|
PrismaManager.setup_database(use_migrate=use_prisma_migrate)
|
||||||
# run prisma db push, before starting server
|
|
||||||
# Save the current working directory
|
|
||||||
original_dir = os.getcwd()
|
|
||||||
# set the working directory to where this script is
|
|
||||||
abspath = os.path.abspath(__file__)
|
|
||||||
dname = os.path.dirname(abspath)
|
|
||||||
os.chdir(dname)
|
|
||||||
try:
|
|
||||||
subprocess.run(
|
|
||||||
["prisma", "db", "push", "--accept-data-loss"]
|
|
||||||
)
|
|
||||||
break # Exit the loop if the subprocess succeeds
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
import time
|
|
||||||
|
|
||||||
print(f"Error: {e}") # noqa
|
|
||||||
time.sleep(random.randrange(start=1, stop=5))
|
|
||||||
finally:
|
|
||||||
os.chdir(original_dir)
|
|
||||||
else:
|
else:
|
||||||
print( # noqa
|
print( # noqa
|
||||||
f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found." # noqa
|
f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found." # noqa
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue