mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
* feat(spend_management_endpoints.py): expose new endpoint for querying user's usage at 1m+ spend logs Allows user to view their spend at 1m+ spend logs * build(schema.prisma): add api_requests to dailyuserspend table * build(migration.sql): add migration file for new column to daily user spend table * build(prisma_client.py): add logic for copying over migration folder, if deploy/migrations present in expected location enables easier testing of prisma migration flow * build(ui/): initial commit successfully using the dailyuserspend table on the UI * refactor(internal_user_endpoints.py): refactor `/user/daily/activity` to give breakdowns by provider/model/key * feat: feature parity (cost page) with existing 'usage' page * build(ui/): add activity tab to new_usage.tsx gets to feature parity on 'All Up' page of 'usage.tsx' * fix(proxy/utils.py): count number of api requests in daily user spend table allows us to see activity by model on new usage tab * style(new_usage.tsx): fix y-axis to be in ascending order of date * fix: fix linting errors * fix: fix ruff check errors
350 lines
13 KiB
Python
350 lines
13 KiB
Python
"""
|
|
This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token.
|
|
"""
|
|
|
|
import asyncio
|
|
import glob
|
|
import os
|
|
import random
|
|
import subprocess
|
|
import time
|
|
import urllib
|
|
import urllib.parse
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Any, Optional, Union
|
|
|
|
from litellm._logging import verbose_proxy_logger
|
|
from litellm.secret_managers.main import str_to_bool
|
|
|
|
|
|
class PrismaWrapper:
    """
    Proxy around a Prisma client that can refresh an RDS IAM auth token.

    Attribute lookups that the wrapper cannot satisfy itself are forwarded to
    the wrapped client. When ``iam_token_db_auth`` is enabled, each forwarded
    lookup first checks whether the token embedded in the ``DATABASE_URL``
    environment variable has expired; if so a new token is generated and the
    wrapped client is recreated.
    """

    def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
        self._original_prisma = original_prisma
        self.iam_token_db_auth = iam_token_db_auth

    def is_token_expired(self, token_url: Optional[str]) -> bool:
        """
        Report whether the signed token carried in ``token_url`` has expired.

        Args:
            token_url: URL whose query string carries ``X-Amz-Date`` and
                ``X-Amz-Expires``. ``None`` is treated as expired.

        Returns:
            True if ``token_url`` is None or the current UTC time is past
            ``X-Amz-Date + X-Amz-Expires`` seconds, otherwise False.

        Raises:
            ValueError: If ``X-Amz-Expires`` or ``X-Amz-Date`` is missing, or
                either value cannot be parsed.
        """
        # Function-scope import: only this method needs `timezone`.
        from datetime import timezone

        if token_url is None:
            return True

        # Decode first so URL-encoded query separators become parseable.
        decoded_url = urllib.parse.unquote(token_url)
        parsed_url = urllib.parse.urlparse(decoded_url)
        query_params = urllib.parse.parse_qs(parsed_url.query)

        # Token lifetime in seconds.
        expires = query_params.get("X-Amz-Expires", [None])[0]
        if expires is None:
            raise ValueError("X-Amz-Expires parameter is missing or invalid.")
        expires_int = int(expires)

        # Token creation time.
        token_time_str = query_params.get("X-Amz-Date", [""])[0]
        if not token_time_str:
            raise ValueError("X-Amz-Date parameter is missing or invalid.")

        try:
            token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ")
        except ValueError as e:
            # Chain the cause so the original parsing error is not lost.
            raise ValueError(f"Invalid X-Amz-Date format: {e}") from e

        expiration_time = token_time + timedelta(seconds=expires_int)

        # `token_time` is a naive datetime expressed in UTC, so compare it with
        # a naive UTC "now". `datetime.utcnow()` is deprecated; derive the same
        # value from an aware datetime instead.
        current_time = datetime.now(timezone.utc).replace(tzinfo=None)

        return current_time > expiration_time

    def get_rds_iam_token(self) -> Optional[str]:
        """
        Generate a fresh IAM token and build the database URL from it.

        Reads ``DATABASE_HOST``, ``DATABASE_PORT``, ``DATABASE_USER``,
        ``DATABASE_NAME`` and ``DATABASE_SCHEMA`` from the environment, and
        stores the resulting URL in the ``DATABASE_URL`` environment variable.

        Returns:
            The new database URL, or None when IAM token auth is disabled.
        """
        if not self.iam_token_db_auth:
            return None

        from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token

        db_host = os.getenv("DATABASE_HOST")
        db_port = os.getenv("DATABASE_PORT")
        db_user = os.getenv("DATABASE_USER")
        db_name = os.getenv("DATABASE_NAME")
        db_schema = os.getenv("DATABASE_SCHEMA")

        token = generate_iam_auth_token(
            db_host=db_host, db_port=db_port, db_user=db_user
        )

        # NOTE(review): the token is interpolated as-is; presumably
        # generate_iam_auth_token returns it already URL-encoded - confirm.
        _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
        if db_schema:
            _db_url += f"?schema={db_schema}"

        os.environ["DATABASE_URL"] = _db_url
        return _db_url

    async def recreate_prisma_client(
        self, new_db_url: str, http_client: Optional[Any] = None
    ):
        """
        Replace the wrapped client with a newly connected ``Prisma`` instance.

        Args:
            new_db_url: Not passed to ``Prisma`` by this method.
                NOTE(review): presumably the new client picks the URL up from
                the ``DATABASE_URL`` environment variable - confirm.
            http_client: Optional value forwarded as ``Prisma(http=...)``.
        """
        from prisma import Prisma  # type: ignore

        if http_client is not None:
            self._original_prisma = Prisma(http=http_client)
        else:
            self._original_prisma = Prisma()

        # NOTE(review): the previous client is not disconnected here.
        await self._original_prisma.connect()

    def __getattr__(self, name: str):
        """
        Forward unknown attribute lookups to the wrapped Prisma client.

        When IAM token auth is enabled and the token in ``DATABASE_URL`` has
        expired, a new token is generated and the client is recreated before
        the attribute is resolved.

        Raises:
            AttributeError: If the wrapper's own state is missing, or the
                wrapped client has no such attribute.
            ValueError: If a new token was needed but could not be obtained.
        """
        # __getattr__ only runs when normal lookup fails. If the wrapper's own
        # fields are missing (instance built via __new__, copy or unpickling),
        # reading them below would re-enter __getattr__ forever.
        if name in ("_original_prisma", "iam_token_db_auth"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        if self.iam_token_db_auth:
            db_url = os.getenv("DATABASE_URL")
            if self.is_token_expired(db_url):
                db_url = self.get_rds_iam_token()
                if not db_url:
                    raise ValueError("Failed to get RDS IAM token")

                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    # No event loop is running in this thread.
                    loop = None

                if loop is not None:
                    # We cannot block on the running loop from sync code, so
                    # the reconnect is scheduled and not awaited; the old
                    # client stays in place until that coroutine completes.
                    asyncio.run_coroutine_threadsafe(
                        self.recreate_prisma_client(db_url), loop
                    )
                else:
                    asyncio.run(self.recreate_prisma_client(db_url))

        # Resolve after any refresh so the synchronous path returns an
        # attribute of the new client rather than the replaced one.
        return getattr(self._original_prisma, name)
|
|
|
|
|
|
class PrismaManager:
    """Static helpers that apply the Prisma schema through the ``prisma`` CLI."""

    @staticmethod
    def _get_prisma_dir() -> str:
        """
        Return the directory one level above the one containing this file.

        ``setup_database`` looks for ``schema.prisma`` and the ``migrations``
        folder in this directory.
        """
        abspath = os.path.abspath(__file__)
        return os.path.dirname(os.path.dirname(abspath))

    @staticmethod
    def _create_baseline_migration(schema_path: str) -> bool:
        """
        Create a baseline ``0_init`` migration and mark it as applied.

        Writes the output of ``prisma migrate diff --from-empty`` to
        ``<prisma_dir>/migrations/0_init/migration.sql``, then runs
        ``prisma migrate resolve --applied 0_init``.

        Args:
            schema_path: Path to the Prisma schema file.

        Returns:
            True if both commands succeeded; False if either timed out or
            exited with a non-zero status.
        """
        init_dir = Path(PrismaManager._get_prisma_dir()) / "migrations" / "0_init"
        init_dir.mkdir(parents=True, exist_ok=True)
        migration_file = init_dir / "migration.sql"

        try:
            # Use a context manager so the SQL file handle is always closed,
            # including when the command fails or times out.
            with open(migration_file, "w") as sql_out:
                subprocess.run(
                    [
                        "prisma",
                        "migrate",
                        "diff",
                        "--from-empty",
                        "--to-schema-datamodel",
                        str(schema_path),
                        "--script",
                    ],
                    stdout=sql_out,
                    check=True,
                    timeout=30,
                )  # 30 second timeout

            # Record the baseline as already applied.
            subprocess.run(
                [
                    "prisma",
                    "migrate",
                    "resolve",
                    "--applied",
                    "0_init",
                ],
                check=True,
                timeout=30,
            )

            return True
        except subprocess.TimeoutExpired:
            verbose_proxy_logger.warning(
                "Migration timed out - the database might be under heavy load."
            )
            return False
        except subprocess.CalledProcessError as e:
            verbose_proxy_logger.warning(f"Error creating baseline migration: {e}")
            return False

    @staticmethod
    def _copy_spend_tracking_migrations(prisma_dir: str) -> bool:
        """
        Copy migrations from the deploy directory into ``prisma_dir``, if present.

        The source is ``deploy/migrations`` located three levels above this
        file's directory; the whole folder is copied recursively into
        ``<prisma_dir>/migrations``.

        Args:
            prisma_dir: Directory that should receive the ``migrations`` folder.

        Returns:
            True if the deploy migrations directory existed and was copied,
            False if it does not exist or the copy failed.
        """
        import shutil

        try:
            current_dir = Path(__file__).parent

            # ../../../deploy/migrations relative to this file's directory
            deploy_migrations_dir = (
                current_dir.parent.parent.parent / "deploy" / "migrations"
            )
            local_migrations_dir = Path(prisma_dir + "/migrations")

            if not deploy_migrations_dir.exists():
                return False

            local_migrations_dir.mkdir(parents=True, exist_ok=True)
            shutil.copytree(
                deploy_migrations_dir, local_migrations_dir, dirs_exist_ok=True
            )
            return True
        except Exception as e:
            # Best-effort step: keep returning False, but leave a trace
            # instead of hiding the failure entirely.
            verbose_proxy_logger.warning(
                f"Failed to copy migrations from deploy directory: {e}"
            )
            return False

    @staticmethod
    def _get_migration_names(migrations_dir: str) -> list:
        """
        Return the names of all sub-directories holding a ``migration.sql``.

        Names are sorted because glob order is filesystem dependent; sorting
        makes the order in which migrations are processed deterministic.
        """
        migration_paths = glob.glob(f"{migrations_dir}/*/migration.sql")
        return sorted(Path(p).parent.name for p in migration_paths)

    @staticmethod
    def _resolve_all_migrations(migrations_dir: str):
        """
        Mark every migration found in ``migrations_dir`` as applied.

        Runs ``prisma migrate resolve --applied <name>`` per migration. A
        failing command is logged as a warning unless its stderr says the
        migration is already recorded as applied. A timeout is not handled
        here and propagates to the caller.
        """
        for migration_name in PrismaManager._get_migration_names(migrations_dir):
            try:
                verbose_proxy_logger.info(f"Resolving migration: {migration_name}")
                subprocess.run(
                    ["prisma", "migrate", "resolve", "--applied", migration_name],
                    timeout=60,
                    check=True,
                    capture_output=True,
                    text=True,
                )
                verbose_proxy_logger.debug(f"Resolved migration: {migration_name}")
            except subprocess.CalledProcessError as e:
                stderr = e.stderr or ""
                if "is already recorded as applied in the database." not in stderr:
                    verbose_proxy_logger.warning(
                        f"Failed to resolve migration {migration_name}: {e.stderr}"
                    )

    @staticmethod
    def setup_database(use_migrate: bool = False) -> bool:
        """
        Set up the database using either prisma migrate or prisma db push.

        Makes up to four attempts. Each attempt runs from the prisma directory
        and restores the previous working directory afterwards.

        Args:
            use_migrate: When True, copy any deploy migrations and run
                ``prisma migrate deploy``, falling back to a baseline
                migration on a P3005 "database schema is not empty" error.
                When False, run ``prisma db push --accept-data-loss``.

        Returns:
            bool: True if setup was successful, False otherwise
        """
        max_attempts = 4
        prisma_dir = PrismaManager._get_prisma_dir()
        schema_path = prisma_dir + "/schema.prisma"
        migrations_dir = os.path.join(prisma_dir, "migrations")

        for attempt in range(max_attempts):
            attempts_left = max_attempts - 1 - attempt
            original_dir = os.getcwd()
            os.chdir(prisma_dir)
            try:
                if use_migrate:
                    PrismaManager._copy_spend_tracking_migrations(
                        prisma_dir
                    )  # place a migration in the migrations directory
                    verbose_proxy_logger.info("Running prisma migrate deploy")
                    try:
                        subprocess.run(
                            ["prisma", "migrate", "deploy"],
                            timeout=60,
                            check=True,
                            capture_output=True,
                            text=True,
                        )
                        verbose_proxy_logger.info("prisma migrate deploy completed")

                        # Resolve all migrations in the migrations directory
                        PrismaManager._resolve_all_migrations(migrations_dir)

                        return True
                    except subprocess.CalledProcessError as e:
                        verbose_proxy_logger.warning(
                            f"prisma db error: {e.stderr}, e: {e.stdout}"
                        )
                        stderr = e.stderr or ""
                        if (
                            "P3005" in stderr
                            and "database schema is not empty" in stderr
                        ):
                            verbose_proxy_logger.info("Creating baseline migration")
                            if PrismaManager._create_baseline_migration(schema_path):
                                verbose_proxy_logger.info(
                                    "Resolving all migrations after baseline"
                                )

                                # Resolve all migrations after baseline
                                PrismaManager._resolve_all_migrations(migrations_dir)

                                return True
                        # Any other failure falls through to the next attempt.
                else:
                    # Use prisma db push with increased timeout
                    subprocess.run(
                        ["prisma", "db", "push", "--accept-data-loss"],
                        timeout=60,
                        check=True,
                    )
                    return True
            except subprocess.TimeoutExpired:
                verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
                # Back off only when another attempt will follow.
                if attempts_left > 0:
                    time.sleep(random.randrange(5, 15))
            except subprocess.CalledProcessError as e:
                retry_msg = (
                    f" Retrying... ({attempts_left} attempts left)"
                    if attempts_left > 0
                    else ""
                )
                verbose_proxy_logger.warning(
                    f"The process failed to execute. Details: {e}.{retry_msg}"
                )
                if attempts_left > 0:
                    time.sleep(random.randrange(5, 15))
            finally:
                os.chdir(original_dir)
        return False
|
|
|
|
|
|
def should_update_prisma_schema(
    disable_updates: Optional[Union[bool, str]] = None
) -> bool:
    """
    Decide whether Prisma schema updates should run during startup.

    Args:
        disable_updates: Flag that turns schema updates off. May be a boolean
            or a string; strings are converted with ``str_to_bool``. When
            omitted, the ``DISABLE_SCHEMA_UPDATE`` environment variable is
            used, defaulting to "false".

    Returns:
        bool: False when the resolved flag is truthy (updates disabled),
        True otherwise.

    Examples:
        >>> should_update_prisma_schema()  # Checks DISABLE_SCHEMA_UPDATE env var
        >>> should_update_prisma_schema(True)  # Explicitly disable updates
        >>> should_update_prisma_schema("false")  # Enable updates using string
    """
    flag = disable_updates

    # Fall back to the environment when the caller gave no explicit value.
    if flag is None:
        flag = os.getenv("DISABLE_SCHEMA_UPDATE", "false")

    # Textual values go through the shared string-to-bool helper.
    if isinstance(flag, str):
        flag = str_to_bool(flag)

    if flag:
        return False
    return True
|