litellm-mirror/litellm/proxy/db/prisma_client.py
Krish Dholakia cdcc8ea9b7
Connect UI to "LiteLLM_DailyUserSpend" spend table - enables usage tab to work at 1m+ spend logs (#9603)
* feat(spend_management_endpoints.py): expose new endpoint for querying user's usage at 1m+ spend logs

Allows user to view their spend at 1m+ spend logs

* build(schema.prisma): add api_requests to dailyuserspend table

* build(migration.sql): add migration file for new column to daily user spend table

* build(prisma_client.py): add logic for copying over migration folder, if deploy/migrations present in expected location

enables easier testing of prisma migration flow

* build(ui/): initial commit successfully using the dailyuserspend table on the UI

* refactor(internal_user_endpoints.py): refactor `/user/daily/activity` to give breakdowns by provider/model/key

* feat: feature parity (cost page) with existing 'usage' page

* build(ui/): add activity tab to new_usage.tsx

gets to feature parity on 'All Up' page of 'usage.tsx'

* fix(proxy/utils.py): count number of api requests in daily user spend table

allows us to see activity by model on new usage tab

* style(new_usage.tsx): fix y-axis to be in ascending order of date

* fix: fix linting errors

* fix: fix ruff check errors
2025-03-27 23:29:15 -07:00

350 lines
13 KiB
Python

"""
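This file contains the PrismaWrapper class, which wraps the Prisma client and handles the RDS IAM token,
the PrismaManager class, which sets up the database through the `prisma` CLI (migrate / db push),
and the should_update_prisma_schema helper.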
"""
import asyncio
import glob
import os
import random
import subprocess
import time
import urllib
import urllib.parse
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.secret_managers.main import str_to_bool
class PrismaWrapper:
    """
    Thin proxy around a Prisma client.

    Attribute access is forwarded to the wrapped client. When
    `iam_token_db_auth` is enabled, each forwarded access first checks whether
    the IAM token embedded in the `DATABASE_URL` env var has expired and, if
    so, generates a new token and re-creates the underlying client.
    """

    def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
        self._original_prisma = original_prisma
        self.iam_token_db_auth = iam_token_db_auth

    def is_token_expired(self, token_url: Optional[str]) -> bool:
        """
        Return True when the signed token inside `token_url` is past its expiry.

        Args:
            token_url: URL carrying the `X-Amz-Date` and `X-Amz-Expires` query
                parameters (may be URL-encoded). `None` is treated as expired.

        Returns:
            bool: True if `X-Amz-Date + X-Amz-Expires` seconds is in the past.

        Raises:
            ValueError: if either parameter is missing, `X-Amz-Expires` is not
                an integer, or `X-Amz-Date` is not in `%Y%m%dT%H%M%SZ` format.
        """
        # Function-scope import so this block does not depend on a change to
        # the file-level import list.
        from datetime import timezone

        if token_url is None:
            return True

        # Decode first so URL-encoded '?', '&' and '=' become real separators
        decoded_url = urllib.parse.unquote(token_url)
        parsed_url = urllib.parse.urlparse(decoded_url)
        # Parameters are read from the query component of the decoded URL
        query_params = urllib.parse.parse_qs(parsed_url.query)

        # Token lifetime in seconds
        expires = query_params.get("X-Amz-Expires", [None])[0]
        if expires is None:
            raise ValueError("X-Amz-Expires parameter is missing or invalid.")
        expires_int = int(expires)

        # Token creation time
        token_time_str = query_params.get("X-Amz-Date", [""])[0]
        if not token_time_str:
            raise ValueError("X-Amz-Date parameter is missing or invalid.")

        try:
            token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ")
        except ValueError as e:
            raise ValueError(f"Invalid X-Amz-Date format: {e}") from e

        # The trailing 'Z' marks the timestamp as UTC; compare aware datetimes
        # instead of relying on the deprecated naive `datetime.utcnow()`.
        expiration_time = token_time.replace(tzinfo=timezone.utc) + timedelta(
            seconds=expires_int
        )
        return datetime.now(timezone.utc) > expiration_time

    def get_rds_iam_token(self) -> Optional[str]:
        """
        Generate a fresh IAM auth token and build the database URL from it.

        Reads DATABASE_HOST / DATABASE_PORT / DATABASE_USER / DATABASE_NAME /
        DATABASE_SCHEMA from the environment and, as a side effect, writes the
        resulting URL to the `DATABASE_URL` env var.

        Returns:
            The new database URL, or None when IAM token auth is disabled.
        """
        if self.iam_token_db_auth:
            from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token

            db_host = os.getenv("DATABASE_HOST")
            db_port = os.getenv("DATABASE_PORT")
            db_user = os.getenv("DATABASE_USER")
            db_name = os.getenv("DATABASE_NAME")
            db_schema = os.getenv("DATABASE_SCHEMA")

            token = generate_iam_auth_token(
                db_host=db_host, db_port=db_port, db_user=db_user
            )

            _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
            if db_schema:
                _db_url += f"?schema={db_schema}"

            os.environ["DATABASE_URL"] = _db_url
            return _db_url
        return None

    async def recreate_prisma_client(
        self, new_db_url: str, http_client: Optional[Any] = None
    ):
        """
        Replace the wrapped client with a newly constructed, connected Prisma client.

        Args:
            new_db_url: Not passed to Prisma here; kept for interface
                compatibility. NOTE(review): the new client presumably picks
                the URL up from the `DATABASE_URL` env var that
                `get_rds_iam_token` sets - confirm.
            http_client: Optional value forwarded as `Prisma(http=...)`.
        """
        from prisma import Prisma  # type: ignore

        if http_client is not None:
            self._original_prisma = Prisma(http=http_client)
        else:
            self._original_prisma = Prisma()

        await self._original_prisma.connect()

    def __getattr__(self, name: str):
        """
        Forward attribute access to the wrapped client, refreshing the IAM token if needed.

        Raises:
            AttributeError: if the attribute does not exist on the wrapped
                client, or if the wrapper's own state has not been set yet.
            ValueError: if the token is expired and a new one cannot be obtained.
        """
        # __getattr__ only runs when normal lookup fails. If the wrapper's own
        # fields are missing (bare instance from __new__ / copy / unpickle),
        # reading self._original_prisma below would re-enter this method
        # forever, so fail fast instead.
        if name in ("_original_prisma", "iam_token_db_auth"):
            raise AttributeError(name)

        original_attr = getattr(self._original_prisma, name)

        if self.iam_token_db_auth:
            db_url = os.getenv("DATABASE_URL")
            if self.is_token_expired(db_url):
                db_url = self.get_rds_iam_token()
                if not db_url:
                    raise ValueError("Failed to get RDS IAM token")

                # get_running_loop() replaces the deprecated
                # get_event_loop()/is_running() probe.
                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    loop = None

                if loop is not None:
                    # Inside a running loop we cannot block: schedule the
                    # reconnect and carry on.
                    asyncio.run_coroutine_threadsafe(
                        self.recreate_prisma_client(db_url), loop
                    )
                else:
                    asyncio.run(self.recreate_prisma_client(db_url))

        # NOTE(review): this attribute was resolved before any refresh, so on
        # the call that triggers a refresh it still belongs to the previous
        # client object.
        return original_attr
class PrismaManager:
    """Static helpers that set up the database by shelling out to the `prisma` CLI."""

    @staticmethod
    def _get_prisma_dir() -> str:
        """
        Return the parent of the directory containing this file.

        `setup_database` looks for `schema.prisma` and the `migrations/`
        folder inside this directory.
        """
        abspath = os.path.abspath(__file__)
        dname = os.path.dirname(os.path.dirname(abspath))
        return dname

    @staticmethod
    def _create_baseline_migration(schema_path: str) -> bool:
        """
        Create a baseline (`0_init`) migration for an existing database.

        Writes the output of `prisma migrate diff --from-empty` to
        `migrations/0_init/migration.sql` and then marks `0_init` as applied.

        Args:
            schema_path: Path to the Prisma schema file to diff against.

        Returns:
            bool: True on success, False if either CLI call times out or exits
            with a non-zero status.
        """
        prisma_dir = PrismaManager._get_prisma_dir()
        init_dir = Path(prisma_dir) / "migrations" / "0_init"

        # Create migrations/0_init directory
        init_dir.mkdir(parents=True, exist_ok=True)

        migration_file = init_dir / "migration.sql"

        try:
            # The `with` block closes the SQL file even when the CLI call
            # fails or times out.
            with open(migration_file, "w") as sql_out:
                subprocess.run(
                    [
                        "prisma",
                        "migrate",
                        "diff",
                        "--from-empty",
                        "--to-schema-datamodel",
                        str(schema_path),
                        "--script",
                    ],
                    stdout=sql_out,
                    check=True,
                    timeout=30,
                )  # 30 second timeout

            # Mark the baseline migration as applied
            subprocess.run(
                [
                    "prisma",
                    "migrate",
                    "resolve",
                    "--applied",
                    "0_init",
                ],
                check=True,
                timeout=30,
            )

            return True
        except subprocess.TimeoutExpired:
            verbose_proxy_logger.warning(
                "Migration timed out - the database might be under heavy load."
            )
            return False
        except subprocess.CalledProcessError as e:
            verbose_proxy_logger.warning(f"Error creating baseline migration: {e}")
            return False

    @staticmethod
    def _copy_spend_tracking_migrations(prisma_dir: str) -> bool:
        """
        Check for and copy over spend tracking migrations if they exist in the deploy directory.

        Args:
            prisma_dir: Directory whose `migrations/` sub-folder receives the copy.

        Returns:
            bool: True if migrations were found and copied, False otherwise
            (including when the copy fails - this step is best-effort).
        """
        import shutil

        try:
            # Directory containing this file
            current_dir = Path(__file__).parent

            # Three levels above this file's directory: ../../../deploy/migrations
            deploy_migrations_dir = (
                current_dir.parent.parent.parent / "deploy" / "migrations"
            )

            # Local migrations directory
            local_migrations_dir = Path(prisma_dir + "/migrations")

            if deploy_migrations_dir.exists():
                # Create local migrations directory if it doesn't exist
                local_migrations_dir.mkdir(parents=True, exist_ok=True)

                # Copy the entire migrations folder recursively, merging into
                # whatever is already there
                shutil.copytree(
                    deploy_migrations_dir, local_migrations_dir, dirs_exist_ok=True
                )

                return True
            return False
        except Exception as e:
            # Best-effort: never block startup on this, but leave a trace so a
            # failed copy can be diagnosed.
            verbose_proxy_logger.warning(
                f"Failed to copy migrations into {prisma_dir}/migrations: {e}"
            )
            return False

    @staticmethod
    def _get_migration_names(migrations_dir: str) -> list:
        """
        Get all migration directory names from the migrations folder.

        A sub-directory counts as a migration when it contains a
        `migration.sql` file. Names are returned sorted so callers process
        them in a deterministic order (glob order is arbitrary).
        """
        migration_paths = glob.glob(f"{migrations_dir}/*/migration.sql")
        return sorted(Path(p).parent.name for p in migration_paths)

    @staticmethod
    def _resolve_all_migrations(migrations_dir: str):
        """
        Mark all existing migrations as applied.

        Runs `prisma migrate resolve --applied <name>` for every migration
        found in `migrations_dir`. A failure is logged as a warning unless the
        CLI reports the migration is already recorded as applied. A timeout is
        not handled here and propagates to the caller.
        """
        migration_names = PrismaManager._get_migration_names(migrations_dir)
        for migration_name in migration_names:
            try:
                verbose_proxy_logger.info(f"Resolving migration: {migration_name}")
                subprocess.run(
                    ["prisma", "migrate", "resolve", "--applied", migration_name],
                    timeout=60,
                    check=True,
                    capture_output=True,
                    text=True,
                )
                verbose_proxy_logger.debug(f"Resolved migration: {migration_name}")
            except subprocess.CalledProcessError as e:
                if "is already recorded as applied in the database." not in e.stderr:
                    verbose_proxy_logger.warning(
                        f"Failed to resolve migration {migration_name}: {e.stderr}"
                    )

    @staticmethod
    def setup_database(use_migrate: bool = False) -> bool:
        """
        Set up the database using either prisma migrate or prisma db push.

        Makes up to four attempts. Each attempt temporarily changes the
        working directory to the prisma directory and restores it afterwards.

        Args:
            use_migrate: When True run `prisma migrate deploy` (creating a
                baseline migration if the CLI reports P3005 on a non-empty
                schema); otherwise run `prisma db push --accept-data-loss`.

        Returns:
            bool: True if setup was successful, False otherwise
        """
        max_attempts = 4

        for attempt in range(max_attempts):
            # Sleeping after the final attempt would only delay the failure.
            is_last_attempt = attempt == max_attempts - 1

            original_dir = os.getcwd()
            prisma_dir = PrismaManager._get_prisma_dir()
            schema_path = prisma_dir + "/schema.prisma"
            os.chdir(prisma_dir)
            try:
                if use_migrate:
                    PrismaManager._copy_spend_tracking_migrations(
                        prisma_dir
                    )  # place a migration in the migrations directory
                    verbose_proxy_logger.info("Running prisma migrate deploy")
                    try:
                        subprocess.run(
                            ["prisma", "migrate", "deploy"],
                            timeout=60,
                            check=True,
                            capture_output=True,
                            text=True,
                        )
                        verbose_proxy_logger.info("prisma migrate deploy completed")

                        # Resolve all migrations in the migrations directory
                        migrations_dir = os.path.join(prisma_dir, "migrations")
                        PrismaManager._resolve_all_migrations(migrations_dir)

                        return True
                    except subprocess.CalledProcessError as e:
                        verbose_proxy_logger.warning(
                            f"prisma db error: {e.stderr}, e: {e.stdout}"
                        )
                        if (
                            "P3005" in e.stderr
                            and "database schema is not empty" in e.stderr
                        ):
                            verbose_proxy_logger.info("Creating baseline migration")
                            if PrismaManager._create_baseline_migration(schema_path):
                                verbose_proxy_logger.info(
                                    "Resolving all migrations after baseline"
                                )

                                # Resolve all migrations after baseline
                                migrations_dir = os.path.join(prisma_dir, "migrations")
                                PrismaManager._resolve_all_migrations(migrations_dir)

                                return True
                        # Any other deploy error (or a failed baseline) falls
                        # through to the next attempt.
                else:
                    # Push the schema directly, without migrations
                    subprocess.run(
                        ["prisma", "db", "push", "--accept-data-loss"],
                        timeout=60,
                        check=True,
                    )
                    return True
            except subprocess.TimeoutExpired:
                verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
                if not is_last_attempt:
                    time.sleep(random.randrange(5, 15))
            except subprocess.CalledProcessError as e:
                attempts_left = max_attempts - 1 - attempt
                retry_msg = (
                    f" Retrying... ({attempts_left} attempts left)"
                    if attempts_left > 0
                    else ""
                )
                verbose_proxy_logger.warning(
                    f"The process failed to execute. Details: {e}.{retry_msg}"
                )
                if not is_last_attempt:
                    time.sleep(random.randrange(5, 15))
            finally:
                os.chdir(original_dir)
        return False
def should_update_prisma_schema(
    disable_updates: Optional[Union[bool, str]] = None
) -> bool:
    """
    Decide whether Prisma schema updates should run at startup.

    Args:
        disable_updates: Flag that turns schema updates off. May be a bool or
            a string (converted with `str_to_bool`). When omitted, the
            DISABLE_SCHEMA_UPDATE env var is read, defaulting to "false".

    Returns:
        bool: False when the resolved flag is truthy (updates disabled),
        True otherwise.

    Examples:
        >>> should_update_prisma_schema()  # Checks DISABLE_SCHEMA_UPDATE env var
        >>> should_update_prisma_schema(True)  # Explicitly disable updates
        >>> should_update_prisma_schema("false")  # Enable updates using string
    """
    flag = disable_updates

    # No explicit argument: fall back to the environment
    if flag is None:
        flag = os.getenv("DISABLE_SCHEMA_UPDATE", "false")

    # String values are converted by the project's str_to_bool helper
    if isinstance(flag, str):
        flag = str_to_bool(flag)

    if flag:
        return False
    return True