mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge branch 'main' into fix/reset-end-user-budget-by-duration
This commit is contained in:
commit
e0e9ee8d7e
1066 changed files with 69987 additions and 14328 deletions
|
@ -10,14 +10,27 @@ import traceback
|
|||
from datetime import datetime, timedelta, timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union, overload
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from litellm.constants import MAX_TEAM_LIST_LIMIT
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
CommonProxyErrors,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
SpendLogsMetadata,
|
||||
SpendLogsPayload,
|
||||
)
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
|
@ -61,8 +74,10 @@ from litellm.proxy.db.create_views import (
|
|||
create_missing_views,
|
||||
should_create_missing_views,
|
||||
)
|
||||
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
|
||||
from litellm.proxy.db.log_db_metrics import log_db_metrics
|
||||
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||
from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook
|
||||
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
|
||||
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
|
||||
from litellm.proxy.hooks.parallel_request_limiter import (
|
||||
|
@ -71,12 +86,12 @@ from litellm.proxy.hooks.parallel_request_limiter import (
|
|||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
|
||||
from litellm.types.utils import CallTypes, LoggedLiteLLMParams
|
||||
from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
@ -263,6 +278,8 @@ class ProxyLogging:
|
|||
)
|
||||
self.premium_user = premium_user
|
||||
self.service_logging_obj = ServiceLogging()
|
||||
self.db_spend_update_writer = DBSpendUpdateWriter()
|
||||
self.proxy_hook_mapping: Dict[str, CustomLogger] = {}
|
||||
|
||||
def startup_event(
|
||||
self,
|
||||
|
@ -335,11 +352,38 @@ class ProxyLogging:
|
|||
|
||||
if redis_cache is not None:
|
||||
self.internal_usage_cache.dual_cache.redis_cache = redis_cache
|
||||
self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache
|
||||
self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache
|
||||
|
||||
def _add_proxy_hooks(self, llm_router: Optional[Router] = None):
    """
    Add proxy hooks to litellm.callbacks

    For every hook name in PROXY_HOOKS: resolve the hook via get_proxy_hook,
    construct it with only the dependencies its signature names, register the
    instance through litellm.logging_callback_manager.add_litellm_callback,
    and remember it in self.proxy_hook_mapping under the hook name.

    Args:
        llm_router: Not read by this method; kept so existing callers that
            pass it keep working.
    """
    # Loop-invariant, so import once here instead of on every iteration.
    import inspect

    # Imported at call time rather than at module import time.
    # NOTE(review): presumably avoids a circular import with proxy_server - confirm.
    from litellm.proxy.proxy_server import prisma_client

    for hook in PROXY_HOOKS:
        proxy_hook = get_proxy_hook(hook)

        # Hooks differ in what they need; only pass the arguments that the
        # hook's signature actually declares.
        expected_args = inspect.getfullargspec(proxy_hook).args
        passed_in_args: Dict[str, Any] = {}
        if "internal_usage_cache" in expected_args:
            passed_in_args["internal_usage_cache"] = self.internal_usage_cache
        if "prisma_client" in expected_args:
            passed_in_args["prisma_client"] = prisma_client

        proxy_hook_obj = cast(CustomLogger, proxy_hook(**passed_in_args))
        litellm.logging_callback_manager.add_litellm_callback(proxy_hook_obj)

        # Keep a handle so the instance can be looked up by hook name later.
        self.proxy_hook_mapping[hook] = proxy_hook_obj
|
||||
|
||||
def get_proxy_hook(self, hook: str) -> Optional[CustomLogger]:
    """
    Look up a proxy hook instance by name in self.proxy_hook_mapping.

    Args:
        hook: Name the hook was stored under.

    Returns:
        The stored hook object, or None when no entry exists for that name.
    """
    registered_hook = self.proxy_hook_mapping.get(hook)
    return registered_hook
|
||||
|
||||
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check) # type: ignore
|
||||
self._add_proxy_hooks(llm_router)
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) # type: ignore
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, str):
|
||||
|
@ -914,7 +958,7 @@ class ProxyLogging:
|
|||
async def post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
|
||||
response: LLMResponseTypes,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
):
|
||||
"""
|
||||
|
@ -922,6 +966,9 @@ class ProxyLogging:
|
|||
|
||||
Covers:
|
||||
1. /chat/completions
|
||||
2. /embeddings
|
||||
3. /image/generation
|
||||
4. /files
|
||||
"""
|
||||
|
||||
for callback in litellm.callbacks:
|
||||
|
@ -1096,12 +1143,6 @@ def jsonify_object(data: dict) -> dict:
|
|||
|
||||
|
||||
class PrismaClient:
|
||||
user_list_transactons: dict = {}
|
||||
end_user_list_transactons: dict = {}
|
||||
key_list_transactons: dict = {}
|
||||
team_list_transactons: dict = {}
|
||||
team_member_list_transactons: dict = {} # key is ["team_id" + "user_id"]
|
||||
org_list_transactons: dict = {}
|
||||
spend_log_transactions: List = []
|
||||
|
||||
def __init__(
|
||||
|
@ -1140,6 +1181,41 @@ class PrismaClient:
|
|||
) # Client to connect to Prisma db
|
||||
verbose_proxy_logger.debug("Success - Created Prisma Client")
|
||||
|
||||
def get_request_status(
    self, payload: Union[dict, SpendLogsPayload]
) -> Literal["success", "failure"]:
    """
    Classify a request as "success" or "failure" from its payload metadata.

    The payload's "metadata" entry may be a mapping or a JSON string. The
    request counts as failed only when that metadata has "status" equal to
    "failure"; every other case, including metadata that cannot be parsed
    or read, is reported as "success".

    Args:
        payload (Union[dict, SpendLogsPayload]): Request payload containing metadata

    Returns:
        Literal["success", "failure"]: Request status
    """
    try:
        raw_metadata: Union[Dict, SpendLogsMetadata, str] = payload.get(
            "metadata", {}
        )

        # A string is treated as serialized JSON and decoded before lookup.
        metadata_dict: Union[Dict, SpendLogsMetadata] = (
            cast(Dict, json.loads(raw_metadata))
            if isinstance(raw_metadata, str)
            else raw_metadata
        )

        if metadata_dict.get("status") == "failure":
            return "failure"
        return "success"
    except (json.JSONDecodeError, AttributeError):
        # Invalid JSON, or metadata without a .get method (e.g. None):
        # fall back to "success".
        return "success"
|
||||
|
||||
def hash_token(self, token: str):
|
||||
# Hash the string using SHA-256
|
||||
hashed_token = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
@ -1578,7 +1654,9 @@ class PrismaClient:
|
|||
where={"team_id": {"in": team_id_list}}
|
||||
)
|
||||
elif query_type == "find_all" and team_id_list is None:
|
||||
response = await self.db.litellm_teamtable.find_many(take=20)
|
||||
response = await self.db.litellm_teamtable.find_many(
|
||||
take=MAX_TEAM_LIST_LIMIT
|
||||
)
|
||||
return response
|
||||
elif table_name == "user_notification":
|
||||
if query_type == "find_unique":
|
||||
|
@ -2450,7 +2528,10 @@ def _hash_token_if_needed(token: str) -> str:
|
|||
class ProxyUpdateSpend:
|
||||
@staticmethod
|
||||
async def update_end_user_spend(
|
||||
n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging
|
||||
n_retry_times: int,
|
||||
prisma_client: PrismaClient,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
end_user_list_transactions: Dict[str, float],
|
||||
):
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
|
@ -2462,7 +2543,7 @@ class ProxyUpdateSpend:
|
|||
for (
|
||||
end_user_id,
|
||||
response_cost,
|
||||
) in prisma_client.end_user_list_transactons.items():
|
||||
) in end_user_list_transactions.items():
|
||||
if litellm.max_end_user_budget is not None:
|
||||
pass
|
||||
batcher.litellm_endusertable.upsert(
|
||||
|
@ -2489,10 +2570,6 @@ class ProxyUpdateSpend:
|
|||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
finally:
|
||||
prisma_client.end_user_list_transactons = (
|
||||
{}
|
||||
) # reset the end user list transactions - prevent bad data from causing issues
|
||||
|
||||
@staticmethod
|
||||
async def update_spend_logs(
|
||||
|
@ -2596,202 +2673,11 @@ async def update_spend( # noqa: PLR0915
|
|||
spend_logs: list,
|
||||
"""
|
||||
n_retry_times = 3
|
||||
i = None
|
||||
### UPDATE USER TABLE ###
|
||||
if len(prisma_client.user_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
user_id,
|
||||
response_cost,
|
||||
) in prisma_client.user_list_transactons.items():
|
||||
batcher.litellm_usertable.update_many(
|
||||
where={"user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.user_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE END-USER TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"End-User Spend transactions: {}".format(
|
||||
len(prisma_client.end_user_list_transactons.keys())
|
||||
)
|
||||
await proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler(
|
||||
prisma_client=prisma_client,
|
||||
n_retry_times=n_retry_times,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if len(prisma_client.end_user_list_transactons.keys()) > 0:
|
||||
await ProxyUpdateSpend.update_end_user_spend(
|
||||
n_retry_times=n_retry_times,
|
||||
prisma_client=prisma_client,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
### UPDATE KEY TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"KEY Spend transactions: {}".format(
|
||||
len(prisma_client.key_list_transactons.keys())
|
||||
)
|
||||
)
|
||||
if len(prisma_client.key_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
token,
|
||||
response_cost,
|
||||
) in prisma_client.key_list_transactons.items():
|
||||
batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"token": token},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.key_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE TEAM TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"Team Spend transactions: {}".format(
|
||||
len(prisma_client.team_list_transactons.keys())
|
||||
)
|
||||
)
|
||||
if len(prisma_client.team_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
team_id,
|
||||
response_cost,
|
||||
) in prisma_client.team_list_transactons.items():
|
||||
verbose_proxy_logger.debug(
|
||||
"Updating spend for team id={} by {}".format(
|
||||
team_id, response_cost
|
||||
)
|
||||
)
|
||||
batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"team_id": team_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.team_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE TEAM Membership TABLE with spend ###
|
||||
if len(prisma_client.team_member_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
key,
|
||||
response_cost,
|
||||
) in prisma_client.team_member_list_transactons.items():
|
||||
# key is "team_id::<value>::user_id::<value>"
|
||||
team_id = key.split("::")[1]
|
||||
user_id = key.split("::")[3]
|
||||
|
||||
batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"team_id": team_id, "user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.team_member_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE ORG TABLE ###
|
||||
if len(prisma_client.org_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
org_id,
|
||||
response_cost,
|
||||
) in prisma_client.org_list_transactons.items():
|
||||
batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"organization_id": org_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.org_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE SPEND LOGS ###
|
||||
verbose_proxy_logger.debug(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue