Merge branch 'main' into fix/reset-end-user-budget-by-duration

This commit is contained in:
Laurien 2025-04-23 14:20:54 +02:00 committed by GitHub
commit e0e9ee8d7e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
1066 changed files with 69987 additions and 14328 deletions

View file

@ -10,14 +10,27 @@ import traceback
from datetime import datetime, timedelta, timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union, overload
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Union,
cast,
overload,
)
from litellm.constants import MAX_TEAM_LIST_LIMIT
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
CommonProxyErrors,
ProxyErrorTypes,
ProxyException,
SpendLogsMetadata,
SpendLogsPayload,
)
from litellm.types.guardrails import GuardrailEventHooks
@ -61,8 +74,10 @@ from litellm.proxy.db.create_views import (
create_missing_views,
should_create_missing_views,
)
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
from litellm.proxy.db.log_db_metrics import log_db_metrics
from litellm.proxy.db.prisma_client import PrismaWrapper
from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.parallel_request_limiter import (
@ -71,12 +86,12 @@ from litellm.proxy.hooks.parallel_request_limiter import (
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.secret_managers.main import str_to_bool
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
from litellm.types.utils import CallTypes, LoggedLiteLLMParams
from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any
@ -263,6 +278,8 @@ class ProxyLogging:
)
self.premium_user = premium_user
self.service_logging_obj = ServiceLogging()
self.db_spend_update_writer = DBSpendUpdateWriter()
self.proxy_hook_mapping: Dict[str, CustomLogger] = {}
def startup_event(
self,
@ -335,11 +352,38 @@ class ProxyLogging:
if redis_cache is not None:
self.internal_usage_cache.dual_cache.redis_cache = redis_cache
self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache
self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache
def _add_proxy_hooks(self, llm_router: Optional[Router] = None):
    """
    Instantiate every registered proxy hook and add it to litellm.callbacks

    For each name in PROXY_HOOKS the hook class is looked up with the
    module-level `get_proxy_hook`, constructed with whichever of
    `internal_usage_cache` / `prisma_client` its constructor declares,
    registered through `litellm.logging_callback_manager`, and stored in
    `self.proxy_hook_mapping` under its hook name.

    Args:
        llm_router: Not used by this method; kept so the signature stays
            compatible with existing callers.
    """
    # Loop-invariant import: resolved once instead of on every iteration.
    import inspect

    # Imported inside the function rather than at module level - presumably
    # to avoid a circular import with proxy_server (TODO confirm).
    from litellm.proxy.proxy_server import prisma_client

    for hook in PROXY_HOOKS:
        proxy_hook = get_proxy_hook(hook)

        # Inspect the constructor so only the dependencies it asks for are
        # passed. Keyword-only parameters are included so a hook declaring
        # `*, prisma_client` still receives its dependency.
        arg_spec = inspect.getfullargspec(proxy_hook)
        expected_args = set(arg_spec.args) | set(arg_spec.kwonlyargs)

        passed_in_args: Dict[str, Any] = {}
        if "internal_usage_cache" in expected_args:
            passed_in_args["internal_usage_cache"] = self.internal_usage_cache
        if "prisma_client" in expected_args:
            passed_in_args["prisma_client"] = prisma_client

        proxy_hook_obj = cast(CustomLogger, proxy_hook(**passed_in_args))
        litellm.logging_callback_manager.add_litellm_callback(proxy_hook_obj)
        self.proxy_hook_mapping[hook] = proxy_hook_obj
def get_proxy_hook(self, hook: str) -> Optional[CustomLogger]:
    """
    Look up an instantiated proxy hook by name.

    Returns the object stored under `hook` in `self.proxy_hook_mapping`,
    or None when no hook is stored under that name.
    """
    registered_hooks = self.proxy_hook_mapping
    return registered_hooks.get(hook, None)
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check) # type: ignore
self._add_proxy_hooks(llm_router)
litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) # type: ignore
for callback in litellm.callbacks:
if isinstance(callback, str):
@ -914,7 +958,7 @@ class ProxyLogging:
async def post_call_success_hook(
self,
data: dict,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
response: LLMResponseTypes,
user_api_key_dict: UserAPIKeyAuth,
):
"""
@ -922,6 +966,9 @@ class ProxyLogging:
Covers:
1. /chat/completions
2. /embeddings
3. /image/generation
4. /files
"""
for callback in litellm.callbacks:
@ -1096,12 +1143,6 @@ def jsonify_object(data: dict) -> dict:
class PrismaClient:
user_list_transactons: dict = {}
end_user_list_transactons: dict = {}
key_list_transactons: dict = {}
team_list_transactons: dict = {}
team_member_list_transactons: dict = {} # key is ["team_id" + "user_id"]
org_list_transactons: dict = {}
spend_log_transactions: List = []
def __init__(
@ -1140,6 +1181,41 @@ class PrismaClient:
) # Client to connect to Prisma db
verbose_proxy_logger.debug("Success - Created Prisma Client")
def get_request_status(
    self, payload: Union[dict, SpendLogsPayload]
) -> Literal["success", "failure"]:
    """
    Classify a request as "success" or "failure" from its payload metadata.

    The payload's "metadata" entry may be a mapping or a JSON-encoded
    string. The request counts as failed only when that metadata carries
    status == "failure"; anything else is reported as "success".

    Args:
        payload (Union[dict, SpendLogsPayload]): Request payload containing metadata

    Returns:
        Literal["success", "failure"]: Request status
    """
    try:
        raw_metadata = payload.get("metadata", {})
        # Metadata may arrive serialized as a JSON string - decode it first.
        metadata = (
            json.loads(raw_metadata)
            if isinstance(raw_metadata, str)
            else raw_metadata
        )
        if metadata.get("status") == "failure":
            return "failure"
        return "success"
    except (json.JSONDecodeError, AttributeError):
        # Undecodable JSON, or metadata without .get(): default to success.
        return "success"
def hash_token(self, token: str):
# Hash the string using SHA-256
hashed_token = hashlib.sha256(token.encode()).hexdigest()
@ -1578,7 +1654,9 @@ class PrismaClient:
where={"team_id": {"in": team_id_list}}
)
elif query_type == "find_all" and team_id_list is None:
response = await self.db.litellm_teamtable.find_many(take=20)
response = await self.db.litellm_teamtable.find_many(
take=MAX_TEAM_LIST_LIMIT
)
return response
elif table_name == "user_notification":
if query_type == "find_unique":
@ -2450,7 +2528,10 @@ def _hash_token_if_needed(token: str) -> str:
class ProxyUpdateSpend:
@staticmethod
async def update_end_user_spend(
n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging
n_retry_times: int,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
end_user_list_transactions: Dict[str, float],
):
for i in range(n_retry_times + 1):
start_time = time.time()
@ -2462,7 +2543,7 @@ class ProxyUpdateSpend:
for (
end_user_id,
response_cost,
) in prisma_client.end_user_list_transactons.items():
) in end_user_list_transactions.items():
if litellm.max_end_user_budget is not None:
pass
batcher.litellm_endusertable.upsert(
@ -2489,10 +2570,6 @@ class ProxyUpdateSpend:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
finally:
prisma_client.end_user_list_transactons = (
{}
) # reset the end user list transactions - prevent bad data from causing issues
@staticmethod
async def update_spend_logs(
@ -2596,202 +2673,11 @@ async def update_spend( # noqa: PLR0915
spend_logs: list,
"""
n_retry_times = 3
i = None
### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
user_id,
response_cost,
) in prisma_client.user_list_transactons.items():
batcher.litellm_usertable.update_many(
where={"user_id": user_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.user_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times: # If we've reached the maximum number of retries
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
### UPDATE END-USER TABLE ###
verbose_proxy_logger.debug(
"End-User Spend transactions: {}".format(
len(prisma_client.end_user_list_transactons.keys())
)
await proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler(
prisma_client=prisma_client,
n_retry_times=n_retry_times,
proxy_logging_obj=proxy_logging_obj,
)
if len(prisma_client.end_user_list_transactons.keys()) > 0:
await ProxyUpdateSpend.update_end_user_spend(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
)
### UPDATE KEY TABLE ###
verbose_proxy_logger.debug(
"KEY Spend transactions: {}".format(
len(prisma_client.key_list_transactons.keys())
)
)
if len(prisma_client.key_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
token,
response_cost,
) in prisma_client.key_list_transactons.items():
batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists
where={"token": token},
data={"spend": {"increment": response_cost}},
)
prisma_client.key_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times: # If we've reached the maximum number of retries
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
### UPDATE TEAM TABLE ###
verbose_proxy_logger.debug(
"Team Spend transactions: {}".format(
len(prisma_client.team_list_transactons.keys())
)
)
if len(prisma_client.team_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
team_id,
response_cost,
) in prisma_client.team_list_transactons.items():
verbose_proxy_logger.debug(
"Updating spend for team id={} by {}".format(
team_id, response_cost
)
)
batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists
where={"team_id": team_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.team_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times: # If we've reached the maximum number of retries
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
### UPDATE TEAM Membership TABLE with spend ###
if len(prisma_client.team_member_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
key,
response_cost,
) in prisma_client.team_member_list_transactons.items():
# key is "team_id::<value>::user_id::<value>"
team_id = key.split("::")[1]
user_id = key.split("::")[3]
batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists
where={"team_id": team_id, "user_id": user_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.team_member_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times: # If we've reached the maximum number of retries
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
### UPDATE ORG TABLE ###
if len(prisma_client.org_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
org_id,
response_cost,
) in prisma_client.org_list_transactons.items():
batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists
where={"organization_id": org_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.org_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times: # If we've reached the maximum number of retries
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
### UPDATE SPEND LOGS ###
verbose_proxy_logger.debug(