From 8857c9b97825d8f96452684c19d3b23e3cdee112 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 9 Apr 2024 17:58:18 -0700 Subject: [PATCH] test(test_key_generate_prisma.py): add better unit testing for spend logs on proxy server --- litellm/proxy/proxy_server.py | 49 ++++++++++++++--------- litellm/tests/test_key_generate_prisma.py | 22 ++++++++++ 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e9d8a0f0f..691bb1adf 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1283,6 +1283,20 @@ async def _PROXY_track_cost_callback( verbose_proxy_logger.debug("error in tracking cost callback - %s", e) +def _set_spend_logs_payload( + payload: dict, prisma_client: PrismaClient, spend_logs_url: Optional[str] = None +): + if prisma_client is not None and spend_logs_url is not None: + if isinstance(payload["startTime"], datetime): + payload["startTime"] = payload["startTime"].isoformat() + if isinstance(payload["endTime"], datetime): + payload["endTime"] = payload["endTime"].isoformat() + prisma_client.spend_log_transactions.append(payload) + elif prisma_client is not None: + prisma_client.spend_log_transactions.append(payload) + return prisma_client + + async def update_database( token, response_cost, @@ -1295,6 +1309,7 @@ async def update_database( end_time=None, ): try: + global prisma_client verbose_proxy_logger.info( f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}" ) @@ -1453,26 +1468,22 @@ async def update_database( ### UPDATE SPEND LOGS ### async def _insert_spend_log_to_db(): try: - # Helper to generate payload to log - payload = get_logging_payload( - kwargs=kwargs, - response_obj=completion_response, - start_time=start_time, - end_time=end_time, - ) + global prisma_client + if prisma_client is not None: + # Helper to generate payload to log + payload = get_logging_payload( + kwargs=kwargs, + response_obj=completion_response, + start_time=start_time, + end_time=end_time, + ) - payload["spend"] = response_cost - if ( - prisma_client is not None - and os.getenv("SPEND_LOGS_URL", None) is not None - ): - if isinstance(payload["startTime"], datetime): - payload["startTime"] = payload["startTime"].isoformat() - if isinstance(payload["endTime"], datetime): - payload["endTime"] = payload["endTime"].isoformat() - prisma_client.spend_log_transactions.append(payload) - elif prisma_client is not None: - prisma_client.spend_log_transactions.append(payload) + payload["spend"] = response_cost + prisma_client = _set_spend_logs_payload( + payload=payload, + spend_logs_url=os.getenv("SPEND_LOGS_URL"), + prisma_client=prisma_client, + ) except Exception as e: verbose_proxy_logger.debug( f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}" diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index c100ac9a1..683927ca9 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -1677,6 +1677,28 @@ def test_get_bearer_token(): assert result == "sk-1234", f"Expected 'valid_token', got '{result}'" +def test_update_logs_with_spend_logs_url(prisma_client): + """ + Unit test for making sure spend logs list is still updated when url passed in + """ + from litellm.proxy.proxy_server import _set_spend_logs_payload + + payload = {"startTime": datetime.now(), "endTime": datetime.now()} + _set_spend_logs_payload(payload=payload, prisma_client=prisma_client) + + assert len(prisma_client.spend_log_transactions) > 0 + + prisma_client.spend_log_transactions = [] + + spend_logs_url = "" + payload = {"startTime": datetime.now(), "endTime": datetime.now()} + _set_spend_logs_payload( + payload=payload, spend_logs_url=spend_logs_url, prisma_client=prisma_client + ) + + assert len(prisma_client.spend_log_transactions) > 0 + + @pytest.mark.asyncio async def test_user_api_key_auth(prisma_client): from litellm.proxy.proxy_server import ProxyException