# litellm-mirror/tests/logging_callback_tests/test_gcs_pub_sub.py
#
# 225 lines
# 7 KiB
# Python

import io
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import asyncio
import litellm
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch
import pytest
import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.gcs_pubsub.pub_sub import *
from datetime import datetime, timedelta
from litellm.types.utils import (
StandardLoggingPayload,
StandardLoggingModelInformation,
StandardLoggingMetadata,
StandardLoggingHiddenParams,
)
# Set litellm's verbose logger to DEBUG as soon as this test module is imported.
verbose_logger.setLevel(level=logging.DEBUG)
def assert_gcs_pubsub_request_matches_expected(
    actual_request_body: dict,
    expected_file_name: str,
):
    """
    Compare a decoded GCS PubSub message body with an expected JSON fixture.

    Run-specific fields (timestamps, ids, response time) are copied from the
    fixture into a copy of the actual body before comparing, so only the
    stable part of the payload is checked. The caller's dict is not modified.

    Args:
        actual_request_body (dict): The decoded message body that was sent to PubSub.
        expected_file_name (str): Name of a JSON file in the ``gcs_pub_sub_body``
            directory next to this test module.

    Raises:
        AssertionError: If the normalized body differs from the fixture.
    """
    # Fixtures live in a directory next to this file, independent of the CWD.
    pwd = os.path.dirname(os.path.realpath(__file__))
    expected_body_path = os.path.join(pwd, "gcs_pub_sub_body", expected_file_name)
    with open(expected_body_path, "r", encoding="utf-8") as f:
        expected_request_body = json.load(f)

    # Values of these fields change on every run, so they cannot be pinned.
    dynamic_fields = [
        "startTime",
        "endTime",
        "completionStartTime",
        "request_id",
        "id",
        "response_time",
    ]

    # Normalize a shallow copy so the caller's payload is left untouched.
    normalized_body = dict(actual_request_body)
    for field in dynamic_fields:
        # Only normalize fields present on both sides. A field missing from the
        # fixture is left as-is, so the assertion below reports it in the diff
        # instead of crashing with a KeyError.
        if field in normalized_body and field in expected_request_body:
            normalized_body[field] = expected_request_body[field]

    assert (
        normalized_body == expected_request_body
    ), f"Difference in request bodies: {json.dumps(normalized_body, indent=2)} != {json.dumps(expected_request_body, indent=2)}"
def assert_gcs_pubsub_request_matches_expected_standard_logging_payload(
    actual_request_body: dict,
    expected_file_name: str,
):
    """
    Check that a decoded GCS PubSub message has the shape of a standard logging payload.

    Only the presence of required keys is verified (plus ``response`` being a
    dict); values are NOT compared against the fixture. The fixture is still
    loaded so that a missing or malformed file fails the test. The caller's
    dict is not modified.

    Args:
        actual_request_body (dict): The decoded message body that was sent to PubSub.
        expected_file_name (str): Name of a JSON file in the ``gcs_pub_sub_body``
            directory next to this test module.

    Raises:
        AssertionError: If required keys are absent (all missing keys are listed)
            or if ``response`` is not a dict.
    """
    pwd = os.path.dirname(os.path.realpath(__file__))
    expected_body_path = os.path.join(pwd, "gcs_pub_sub_body", expected_file_name)
    with open(expected_body_path, "r", encoding="utf-8") as f:
        # Parsed only to validate the fixture; its values are not compared.
        json.load(f)

    required_fields = [
        # Request / response content.
        "custom_llm_provider",
        "hidden_params",
        "messages",
        "response",
        "model",
        "status",
        "stream",
        # Cost, latency and token accounting.
        "response_cost",
        "response_time",
        "completion_tokens",
        "prompt_tokens",
        "total_tokens",
    ]
    missing_fields = [
        field for field in required_fields if field not in actual_request_body
    ]
    assert (
        not missing_fields
    ), f"Standard logging payload is missing fields: {missing_fields}"

    # The previous implementation wrote into response["id"], which implicitly
    # required a dict here; keep that requirement, but state it explicitly.
    assert isinstance(
        actual_request_body["response"], dict
    ), f"Expected 'response' to be a dict, got {type(actual_request_body['response']).__name__}"
@pytest.mark.asyncio
async def test_async_gcs_pub_sub():
    """
    Send one mocked completion through GcsPubSubLogger and check that the
    published message is a standard logging payload (the non-v1 format).

    ``litellm.gcs_pub_sub_use_v1`` is process-global, so it is pinned to False
    for the duration of the call and restored afterwards. This keeps the test
    independent of whatever another test left in the flag.
    """
    import base64

    # Replace the HTTP post and header construction with mocks so no real
    # request or credential lookup is made.
    mock_post = AsyncMock()
    mock_post.return_value.status_code = 202
    mock_post.return_value.text = "Accepted"

    previous_use_v1 = getattr(litellm, "gcs_pub_sub_use_v1", False)
    litellm.gcs_pub_sub_use_v1 = False
    try:
        gcs_pub_sub_logger = GcsPubSubLogger(flush_interval=1)
        gcs_pub_sub_logger.async_httpx_client.post = mock_post
        mock_construct_request_headers = AsyncMock()
        mock_construct_request_headers.return_value = {
            "Authorization": "Bearer mock_token"
        }
        gcs_pub_sub_logger.construct_request_headers = mock_construct_request_headers
        litellm.callbacks = [gcs_pub_sub_logger]

        await litellm.acompletion(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Hello, world!"}],
            mock_response="hi",
        )
        # Give the logger's periodic flush (flush_interval=1) time to run.
        await asyncio.sleep(3)
    finally:
        litellm.gcs_pub_sub_use_v1 = previous_use_v1

    mock_post.assert_called_once()

    actual_url = mock_post.call_args[1]["url"]
    print("sent to url", actual_url)
    assert (
        actual_url
        == "https://pubsub.googleapis.com/v1/projects/reliableKeys/topics/litellmDB:publish"
    )

    # The publish body carries the payload as base64-encoded JSON in messages[0].data.
    publish_body = mock_post.call_args[1]["json"]
    encoded_message = publish_body["messages"][0]["data"]
    decoded_message = base64.b64decode(encoded_message).decode("utf-8")
    actual_request = json.loads(decoded_message)
    print("##########\n")
    print(json.dumps(actual_request, indent=4))
    print("##########\n")

    assert_gcs_pubsub_request_matches_expected_standard_logging_payload(
        actual_request, "standard_logging_payload.json"
    )
@pytest.mark.asyncio
async def test_async_gcs_pub_sub_v1():
    """
    Send one mocked completion through GcsPubSubLogger with
    ``litellm.gcs_pub_sub_use_v1`` enabled, and compare the published message
    with the spend-logs fixture.

    The flag is process-global, so its previous value is restored afterwards
    to avoid leaking v1 mode into tests that run later.
    """
    import base64

    # Replace the HTTP post and header construction with mocks so no real
    # request or credential lookup is made.
    mock_post = AsyncMock()
    mock_post.return_value.status_code = 202
    mock_post.return_value.text = "Accepted"

    previous_use_v1 = getattr(litellm, "gcs_pub_sub_use_v1", False)
    litellm.gcs_pub_sub_use_v1 = True
    try:
        gcs_pub_sub_logger = GcsPubSubLogger(flush_interval=1)
        gcs_pub_sub_logger.async_httpx_client.post = mock_post
        mock_construct_request_headers = AsyncMock()
        mock_construct_request_headers.return_value = {
            "Authorization": "Bearer mock_token"
        }
        gcs_pub_sub_logger.construct_request_headers = mock_construct_request_headers
        litellm.callbacks = [gcs_pub_sub_logger]

        await litellm.acompletion(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Hello, world!"}],
            mock_response="hi",
        )
        # Give the logger's periodic flush (flush_interval=1) time to run.
        await asyncio.sleep(3)
    finally:
        litellm.gcs_pub_sub_use_v1 = previous_use_v1

    mock_post.assert_called_once()

    actual_url = mock_post.call_args[1]["url"]
    print("sent to url", actual_url)
    assert (
        actual_url
        == "https://pubsub.googleapis.com/v1/projects/reliableKeys/topics/litellmDB:publish"
    )

    # The publish body carries the payload as base64-encoded JSON in messages[0].data.
    publish_body = mock_post.call_args[1]["json"]
    encoded_message = publish_body["messages"][0]["data"]
    decoded_message = base64.b64decode(encoded_message).decode("utf-8")
    actual_request = json.loads(decoded_message)
    print("##########\n")
    print(json.dumps(actual_request, indent=4))
    print("##########\n")

    assert_gcs_pubsub_request_matches_expected(
        actual_request, "spend_logs_payload.json"
    )