expose flag to disable_spend_updates

This commit is contained in:
Ishaan Jaff 2025-03-17 20:45:49 -07:00
parent 7c556a008e
commit e6975a56da
2 changed files with 41 additions and 8 deletions

View file

@ -23,6 +23,7 @@ from typing import (
get_origin, get_origin,
get_type_hints, get_type_hints,
) )
from litellm.types.utils import ( from litellm.types.utils import (
ModelResponse, ModelResponse,
ModelResponseStream, ModelResponseStream,
@ -254,6 +255,7 @@ from litellm.proxy.ui_crud_endpoints.proxy_setting_endpoints import (
from litellm.proxy.utils import ( from litellm.proxy.utils import (
PrismaClient, PrismaClient,
ProxyLogging, ProxyLogging,
ProxyUpdateSpend,
_cache_user_row, _cache_user_row,
_get_docs_url, _get_docs_url,
_get_projected_spend_over_limit, _get_projected_spend_over_limit,
@ -924,6 +926,8 @@ async def update_database( # noqa: PLR0915
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}" f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
) )
if ProxyUpdateSpend.disable_spend_updates() is True:
return
if token is not None and isinstance(token, str) and token.startswith("sk-"): if token is not None and isinstance(token, str) and token.startswith("sk-"):
hashed_token = hash_token(token=token) hashed_token = hash_token(token=token)
else: else:
@ -3047,7 +3051,11 @@ async def async_data_generator(
): ):
verbose_proxy_logger.debug("inside generator") verbose_proxy_logger.debug("inside generator")
try: try:
async for chunk in proxy_logging_obj.async_post_call_streaming_iterator_hook(user_api_key_dict=user_api_key_dict, response=response, request_data=request_data): async for chunk in proxy_logging_obj.async_post_call_streaming_iterator_hook(
user_api_key_dict=user_api_key_dict,
response=response,
request_data=request_data,
):
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"async_data_generator: received streaming chunk - {}".format(chunk) "async_data_generator: received streaming chunk - {}".format(chunk)
) )

View file

@ -32,7 +32,13 @@ from fastapi import HTTPException, status
import litellm import litellm
import litellm.litellm_core_utils import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging import litellm.litellm_core_utils.litellm_logging
from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router, ModelResponseStream from litellm import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ModelResponseStream,
Router,
)
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging, ServiceTypes from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm.caching.caching import DualCache, RedisCache from litellm.caching.caching import DualCache, RedisCache
@ -963,7 +969,9 @@ class ProxyLogging:
async def async_post_call_streaming_hook( async def async_post_call_streaming_hook(
self, self,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse], response: Union[
ModelResponse, EmbeddingResponse, ImageResponse, ModelResponseStream
],
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
): ):
""" """
@ -1009,19 +1017,24 @@ class ProxyLogging:
for callback in litellm.callbacks: for callback in litellm.callbacks:
_callback: Optional[CustomLogger] = None _callback: Optional[CustomLogger] = None
if isinstance(callback, str): if isinstance(callback, str):
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(callback) _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
callback
)
else: else:
_callback = callback # type: ignore _callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger): if _callback is not None and isinstance(_callback, CustomLogger):
if not isinstance(_callback, CustomGuardrail) or _callback.should_run_guardrail( if not isinstance(
_callback, CustomGuardrail
) or _callback.should_run_guardrail(
data=request_data, event_type=GuardrailEventHooks.post_call data=request_data, event_type=GuardrailEventHooks.post_call
): ):
response = _callback.async_post_call_streaming_iterator_hook( response = _callback.async_post_call_streaming_iterator_hook(
user_api_key_dict=user_api_key_dict, response=response, request_data=request_data user_api_key_dict=user_api_key_dict,
response=response,
request_data=request_data,
) )
return response return response
async def post_call_streaming_hook( async def post_call_streaming_hook(
self, self,
response: str, response: str,
@ -2469,6 +2482,18 @@ class ProxyUpdateSpend:
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
) )
@staticmethod
def disable_spend_updates() -> bool:
"""
returns True if should not update spend in db
Skips writing spend logs and updates to key, team, user spend to DB
"""
from litellm.proxy.proxy_server import general_settings
if general_settings.get("disable_spend_updates") is True:
return True
return False
async def update_spend( # noqa: PLR0915 async def update_spend( # noqa: PLR0915
prisma_client: PrismaClient, prisma_client: PrismaClient,