mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Move daily user transaction logging outside of 'disable_spend_logs' flag - different tables (#9772)
* refactor(db_spend_update_writer.py): aggregate table is entirely different * test(test_db_spend_update_writer.py): add unit test to ensure if disable_spend_logs is true daily user transactions is still logged * test: fix test
This commit is contained in:
parent
cd0a1e6000
commit
0d503ad8ad
5 changed files with 115 additions and 56 deletions
|
@ -33,8 +33,12 @@ model_list:
|
|||
litellm_settings:
|
||||
num_retries: 0
|
||||
callbacks: ["prometheus"]
|
||||
# json_logs: true
|
||||
|
||||
files_settings:
|
||||
- custom_llm_provider: gemini
|
||||
api_key: os.environ/GEMINI_API_KEY
|
||||
|
||||
|
||||
general_settings:
|
||||
disable_spend_logs: True
|
||||
disable_error_logs: True
|
|
@ -91,6 +91,23 @@ class DBSpendUpdateWriter:
|
|||
else:
|
||||
hashed_token = token
|
||||
|
||||
## CREATE SPEND LOG PAYLOAD ##
|
||||
from litellm.proxy.spend_tracking.spend_tracking_utils import (
|
||||
get_logging_payload,
|
||||
)
|
||||
|
||||
payload = get_logging_payload(
|
||||
kwargs=kwargs,
|
||||
response_obj=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
payload["spend"] = response_cost or 0.0
|
||||
if isinstance(payload["startTime"], datetime):
|
||||
payload["startTime"] = payload["startTime"].isoformat()
|
||||
if isinstance(payload["endTime"], datetime):
|
||||
payload["endTime"] = payload["endTime"].isoformat()
|
||||
|
||||
asyncio.create_task(
|
||||
self._update_user_db(
|
||||
response_cost=response_cost,
|
||||
|
@ -125,11 +142,7 @@ class DBSpendUpdateWriter:
|
|||
)
|
||||
if disable_spend_logs is False:
|
||||
await self._insert_spend_log_to_db(
|
||||
kwargs=kwargs,
|
||||
completion_response=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
response_cost=response_cost,
|
||||
payload=payload,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
else:
|
||||
|
@ -137,6 +150,13 @@ class DBSpendUpdateWriter:
|
|||
"disable_spend_logs=True. Skipping writing spend logs to db. Other spend updates - Key/User/Team table will still occur."
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self.add_spend_log_transaction_to_daily_user_transaction(
|
||||
payload=payload,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("Runs spend update on all tables")
|
||||
except Exception:
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -284,62 +304,25 @@ class DBSpendUpdateWriter:
|
|||
raise e
|
||||
|
||||
async def _insert_spend_log_to_db(
|
||||
self,
|
||||
kwargs: Optional[dict],
|
||||
completion_response: Optional[Union[litellm.ModelResponse, Any, Exception]],
|
||||
start_time: Optional[datetime],
|
||||
end_time: Optional[datetime],
|
||||
response_cost: Optional[float],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
):
|
||||
from litellm.proxy.spend_tracking.spend_tracking_utils import (
|
||||
get_logging_payload,
|
||||
)
|
||||
|
||||
try:
|
||||
if prisma_client:
|
||||
payload = get_logging_payload(
|
||||
kwargs=kwargs,
|
||||
response_obj=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
payload["spend"] = response_cost or 0.0
|
||||
await self._set_spend_logs_payload(
|
||||
payload=payload,
|
||||
spend_logs_url=os.getenv("SPEND_LOGS_URL"),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
raise e
|
||||
|
||||
async def _set_spend_logs_payload(
|
||||
self,
|
||||
payload: Union[dict, SpendLogsPayload],
|
||||
prisma_client: PrismaClient,
|
||||
spend_logs_url: Optional[str] = None,
|
||||
) -> PrismaClient:
|
||||
prisma_client: Optional[PrismaClient] = None,
|
||||
spend_logs_url: Optional[str] = os.getenv("SPEND_LOGS_URL"),
|
||||
) -> Optional[PrismaClient]:
|
||||
verbose_proxy_logger.info(
|
||||
"Writing spend log to db - request_id: {}, spend: {}".format(
|
||||
payload.get("request_id"), payload.get("spend")
|
||||
)
|
||||
)
|
||||
if prisma_client is not None and spend_logs_url is not None:
|
||||
if isinstance(payload["startTime"], datetime):
|
||||
payload["startTime"] = payload["startTime"].isoformat()
|
||||
if isinstance(payload["endTime"], datetime):
|
||||
payload["endTime"] = payload["endTime"].isoformat()
|
||||
prisma_client.spend_log_transactions.append(payload)
|
||||
elif prisma_client is not None:
|
||||
prisma_client.spend_log_transactions.append(payload)
|
||||
else:
|
||||
verbose_proxy_logger.debug(
|
||||
"prisma_client is None. Skipping writing spend logs to db."
|
||||
)
|
||||
|
||||
await self.add_spend_log_transaction_to_daily_user_transaction(
|
||||
payload=payload.copy(),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
return prisma_client
|
||||
|
||||
async def db_update_spend_transaction_handler(
|
||||
|
@ -850,7 +833,7 @@ class DBSpendUpdateWriter:
|
|||
async def add_spend_log_transaction_to_daily_user_transaction(
|
||||
self,
|
||||
payload: Union[dict, SpendLogsPayload],
|
||||
prisma_client: PrismaClient,
|
||||
prisma_client: Optional[PrismaClient] = None,
|
||||
):
|
||||
"""
|
||||
Add a spend log transaction to the `daily_spend_update_queue`
|
||||
|
@ -859,6 +842,11 @@ class DBSpendUpdateWriter:
|
|||
|
||||
If key exists, update the transaction with the new spend and usage
|
||||
"""
|
||||
if prisma_client is None:
|
||||
verbose_proxy_logger.debug(
|
||||
"prisma_client is None. Skipping writing spend logs to db."
|
||||
)
|
||||
return
|
||||
expected_keys = ["user", "startTime", "api_key", "model", "custom_llm_provider"]
|
||||
|
||||
if not all(key in payload for key in expected_keys):
|
||||
|
|
67
tests/litellm/proxy/db/test_db_spend_update_writer.py
Normal file
67
tests/litellm/proxy/db/test_db_spend_update_writer.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_daily_spend_tracking_with_disabled_spend_logs():
|
||||
"""
|
||||
Test that add_spend_log_transaction_to_daily_user_transaction is still called
|
||||
even when disable_spend_logs is True
|
||||
"""
|
||||
# Setup
|
||||
db_writer = DBSpendUpdateWriter()
|
||||
|
||||
# Mock the methods we want to track
|
||||
db_writer._insert_spend_log_to_db = AsyncMock()
|
||||
db_writer.add_spend_log_transaction_to_daily_user_transaction = AsyncMock()
|
||||
|
||||
# Mock the imported modules/variables
|
||||
with patch("litellm.proxy.proxy_server.disable_spend_logs", True), patch(
|
||||
"litellm.proxy.proxy_server.prisma_client", MagicMock()
|
||||
), patch("litellm.proxy.proxy_server.user_api_key_cache", MagicMock()), patch(
|
||||
"litellm.proxy.proxy_server.litellm_proxy_budget_name", "test-budget"
|
||||
):
|
||||
# Test data
|
||||
test_data = {
|
||||
"token": "test-token",
|
||||
"user_id": "test-user",
|
||||
"end_user_id": "test-end-user",
|
||||
"start_time": datetime.now(),
|
||||
"end_time": datetime.now(),
|
||||
"team_id": "test-team",
|
||||
"org_id": "test-org",
|
||||
"completion_response": MagicMock(),
|
||||
"response_cost": 0.1,
|
||||
"kwargs": {"model": "gpt-4", "custom_llm_provider": "openai"},
|
||||
}
|
||||
|
||||
# Call the method
|
||||
await db_writer.update_database(**test_data)
|
||||
|
||||
# Verify that _insert_spend_log_to_db was NOT called (since disable_spend_logs is True)
|
||||
db_writer._insert_spend_log_to_db.assert_not_called()
|
||||
|
||||
# Verify that add_spend_log_transaction_to_daily_user_transaction WAS called
|
||||
assert db_writer.add_spend_log_transaction_to_daily_user_transaction.called
|
||||
|
||||
# Verify the payload passed to add_spend_log_transaction_to_daily_user_transaction
|
||||
call_args = (
|
||||
db_writer.add_spend_log_transaction_to_daily_user_transaction.call_args[1]
|
||||
)
|
||||
assert "payload" in call_args
|
||||
assert call_args["payload"]["spend"] == 0.1
|
||||
assert call_args["payload"]["model"] == "gpt-4"
|
||||
assert call_args["payload"]["custom_llm_provider"] == "openai"
|
|
@ -422,7 +422,7 @@ class TestSpendLogsPayload:
|
|||
|
||||
with patch.object(
|
||||
litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter,
|
||||
"_set_spend_logs_payload",
|
||||
"_insert_spend_log_to_db",
|
||||
) as mock_client, patch.object(litellm.proxy.proxy_server, "prisma_client"):
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-4o",
|
||||
|
@ -516,7 +516,7 @@ class TestSpendLogsPayload:
|
|||
|
||||
with patch.object(
|
||||
litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter,
|
||||
"_set_spend_logs_payload",
|
||||
"_insert_spend_log_to_db",
|
||||
) as mock_client, patch.object(
|
||||
litellm.proxy.proxy_server, "prisma_client"
|
||||
), patch.object(
|
||||
|
@ -612,7 +612,7 @@ class TestSpendLogsPayload:
|
|||
|
||||
with patch.object(
|
||||
litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter,
|
||||
"_set_spend_logs_payload",
|
||||
"_insert_spend_log_to_db",
|
||||
) as mock_client, patch.object(
|
||||
litellm.proxy.proxy_server, "prisma_client"
|
||||
), patch.object(
|
||||
|
|
|
@ -2289,7 +2289,7 @@ async def test_update_logs_with_spend_logs_url(prisma_client):
|
|||
db_spend_update_writer = DBSpendUpdateWriter()
|
||||
|
||||
payload = {"startTime": datetime.now(), "endTime": datetime.now()}
|
||||
await db_spend_update_writer._set_spend_logs_payload(payload=payload, prisma_client=prisma_client)
|
||||
await db_spend_update_writer._insert_spend_log_to_db(payload=payload, prisma_client=prisma_client)
|
||||
|
||||
assert len(prisma_client.spend_log_transactions) > 0
|
||||
|
||||
|
@ -2297,7 +2297,7 @@ async def test_update_logs_with_spend_logs_url(prisma_client):
|
|||
|
||||
spend_logs_url = ""
|
||||
payload = {"startTime": datetime.now(), "endTime": datetime.now()}
|
||||
await db_spend_update_writer._set_spend_logs_payload(
|
||||
await db_spend_update_writer._insert_spend_log_to_db(
|
||||
payload=payload, spend_logs_url=spend_logs_url, prisma_client=prisma_client
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue