Move daily user transaction logging outside of 'disable_spend_logs' flag - different tables (#9772)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 16s
Helm unit test / unit-test (push) Successful in 18s

* refactor(db_spend_update_writer.py): aggregate table is entirely different

* test(test_db_spend_update_writer.py): add unit test to ensure if disable_spend_logs is true daily user transactions is still logged

* test: fix test
This commit is contained in:
Krish Dholakia 2025-04-05 09:58:16 -07:00 committed by GitHub
parent cd0a1e6000
commit 0d503ad8ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 115 additions and 56 deletions

View file

@ -33,8 +33,12 @@ model_list:
litellm_settings:
num_retries: 0
callbacks: ["prometheus"]
# json_logs: true
files_settings:
- custom_llm_provider: gemini
api_key: os.environ/GEMINI_API_KEY
general_settings:
disable_spend_logs: True
disable_error_logs: True

View file

@ -91,6 +91,23 @@ class DBSpendUpdateWriter:
else:
hashed_token = token
## CREATE SPEND LOG PAYLOAD ##
from litellm.proxy.spend_tracking.spend_tracking_utils import (
get_logging_payload,
)
payload = get_logging_payload(
kwargs=kwargs,
response_obj=completion_response,
start_time=start_time,
end_time=end_time,
)
payload["spend"] = response_cost or 0.0
if isinstance(payload["startTime"], datetime):
payload["startTime"] = payload["startTime"].isoformat()
if isinstance(payload["endTime"], datetime):
payload["endTime"] = payload["endTime"].isoformat()
asyncio.create_task(
self._update_user_db(
response_cost=response_cost,
@ -125,11 +142,7 @@ class DBSpendUpdateWriter:
)
if disable_spend_logs is False:
await self._insert_spend_log_to_db(
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
response_cost=response_cost,
payload=payload,
prisma_client=prisma_client,
)
else:
@ -137,6 +150,13 @@ class DBSpendUpdateWriter:
"disable_spend_logs=True. Skipping writing spend logs to db. Other spend updates - Key/User/Team table will still occur."
)
asyncio.create_task(
self.add_spend_log_transaction_to_daily_user_transaction(
payload=payload,
prisma_client=prisma_client,
)
)
verbose_proxy_logger.debug("Runs spend update on all tables")
except Exception:
verbose_proxy_logger.debug(
@ -284,62 +304,25 @@ class DBSpendUpdateWriter:
raise e
async def _insert_spend_log_to_db(
self,
kwargs: Optional[dict],
completion_response: Optional[Union[litellm.ModelResponse, Any, Exception]],
start_time: Optional[datetime],
end_time: Optional[datetime],
response_cost: Optional[float],
prisma_client: Optional[PrismaClient],
):
from litellm.proxy.spend_tracking.spend_tracking_utils import (
get_logging_payload,
)
try:
if prisma_client:
payload = get_logging_payload(
kwargs=kwargs,
response_obj=completion_response,
start_time=start_time,
end_time=end_time,
)
payload["spend"] = response_cost or 0.0
await self._set_spend_logs_payload(
payload=payload,
spend_logs_url=os.getenv("SPEND_LOGS_URL"),
prisma_client=prisma_client,
)
except Exception as e:
verbose_proxy_logger.debug(
f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
)
raise e
async def _set_spend_logs_payload(
self,
payload: Union[dict, SpendLogsPayload],
prisma_client: PrismaClient,
spend_logs_url: Optional[str] = None,
) -> PrismaClient:
prisma_client: Optional[PrismaClient] = None,
spend_logs_url: Optional[str] = os.getenv("SPEND_LOGS_URL"),
) -> Optional[PrismaClient]:
verbose_proxy_logger.info(
"Writing spend log to db - request_id: {}, spend: {}".format(
payload.get("request_id"), payload.get("spend")
)
)
if prisma_client is not None and spend_logs_url is not None:
if isinstance(payload["startTime"], datetime):
payload["startTime"] = payload["startTime"].isoformat()
if isinstance(payload["endTime"], datetime):
payload["endTime"] = payload["endTime"].isoformat()
prisma_client.spend_log_transactions.append(payload)
elif prisma_client is not None:
prisma_client.spend_log_transactions.append(payload)
else:
verbose_proxy_logger.debug(
"prisma_client is None. Skipping writing spend logs to db."
)
await self.add_spend_log_transaction_to_daily_user_transaction(
payload=payload.copy(),
prisma_client=prisma_client,
)
return prisma_client
async def db_update_spend_transaction_handler(
@ -850,7 +833,7 @@ class DBSpendUpdateWriter:
async def add_spend_log_transaction_to_daily_user_transaction(
self,
payload: Union[dict, SpendLogsPayload],
prisma_client: PrismaClient,
prisma_client: Optional[PrismaClient] = None,
):
"""
Add a spend log transaction to the `daily_spend_update_queue`
@ -859,6 +842,11 @@ class DBSpendUpdateWriter:
If key exists, update the transaction with the new spend and usage
"""
if prisma_client is None:
verbose_proxy_logger.debug(
"prisma_client is None. Skipping writing spend logs to db."
)
return
expected_keys = ["user", "startTime", "api_key", "model", "custom_llm_provider"]
if not all(key in payload for key in expected_keys):

View file

@ -0,0 +1,67 @@
import json
import os
import sys
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
@pytest.mark.asyncio
async def test_daily_spend_tracking_with_disabled_spend_logs():
"""
Test that add_spend_log_transaction_to_daily_user_transaction is still called
even when disable_spend_logs is True
"""
# Setup
db_writer = DBSpendUpdateWriter()
# Mock the methods we want to track
db_writer._insert_spend_log_to_db = AsyncMock()
db_writer.add_spend_log_transaction_to_daily_user_transaction = AsyncMock()
# Mock the imported modules/variables
with patch("litellm.proxy.proxy_server.disable_spend_logs", True), patch(
"litellm.proxy.proxy_server.prisma_client", MagicMock()
), patch("litellm.proxy.proxy_server.user_api_key_cache", MagicMock()), patch(
"litellm.proxy.proxy_server.litellm_proxy_budget_name", "test-budget"
):
# Test data
test_data = {
"token": "test-token",
"user_id": "test-user",
"end_user_id": "test-end-user",
"start_time": datetime.now(),
"end_time": datetime.now(),
"team_id": "test-team",
"org_id": "test-org",
"completion_response": MagicMock(),
"response_cost": 0.1,
"kwargs": {"model": "gpt-4", "custom_llm_provider": "openai"},
}
# Call the method
await db_writer.update_database(**test_data)
# Verify that _insert_spend_log_to_db was NOT called (since disable_spend_logs is True)
db_writer._insert_spend_log_to_db.assert_not_called()
# Verify that add_spend_log_transaction_to_daily_user_transaction WAS called
assert db_writer.add_spend_log_transaction_to_daily_user_transaction.called
# Verify the payload passed to add_spend_log_transaction_to_daily_user_transaction
call_args = (
db_writer.add_spend_log_transaction_to_daily_user_transaction.call_args[1]
)
assert "payload" in call_args
assert call_args["payload"]["spend"] == 0.1
assert call_args["payload"]["model"] == "gpt-4"
assert call_args["payload"]["custom_llm_provider"] == "openai"

View file

@ -422,7 +422,7 @@ class TestSpendLogsPayload:
with patch.object(
litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter,
"_set_spend_logs_payload",
"_insert_spend_log_to_db",
) as mock_client, patch.object(litellm.proxy.proxy_server, "prisma_client"):
response = await litellm.acompletion(
model="gpt-4o",
@ -516,7 +516,7 @@ class TestSpendLogsPayload:
with patch.object(
litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter,
"_set_spend_logs_payload",
"_insert_spend_log_to_db",
) as mock_client, patch.object(
litellm.proxy.proxy_server, "prisma_client"
), patch.object(
@ -612,7 +612,7 @@ class TestSpendLogsPayload:
with patch.object(
litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter,
"_set_spend_logs_payload",
"_insert_spend_log_to_db",
) as mock_client, patch.object(
litellm.proxy.proxy_server, "prisma_client"
), patch.object(

View file

@ -2289,7 +2289,7 @@ async def test_update_logs_with_spend_logs_url(prisma_client):
db_spend_update_writer = DBSpendUpdateWriter()
payload = {"startTime": datetime.now(), "endTime": datetime.now()}
await db_spend_update_writer._set_spend_logs_payload(payload=payload, prisma_client=prisma_client)
await db_spend_update_writer._insert_spend_log_to_db(payload=payload, prisma_client=prisma_client)
assert len(prisma_client.spend_log_transactions) > 0
@ -2297,7 +2297,7 @@ async def test_update_logs_with_spend_logs_url(prisma_client):
spend_logs_url = ""
payload = {"startTime": datetime.now(), "endTime": datetime.now()}
await db_spend_update_writer._set_spend_logs_payload(
await db_spend_update_writer._insert_spend_log_to_db(
payload=payload, spend_logs_url=spend_logs_url, prisma_client=prisma_client
)