"""Unit tests for litellm's GCS Pub/Sub logging integration (GcsPubSubLogger)."""

import io
import os
import sys

sys.path.insert(0, os.path.abspath("../.."))

import asyncio
import base64
import gzip
import json
import logging
import time
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, patch

import pytest

import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.gcs_pubsub.pub_sub import *
from litellm.types.utils import (
    StandardLoggingPayload,
    StandardLoggingModelInformation,
    StandardLoggingMetadata,
    StandardLoggingHiddenParams,
)

verbose_logger.setLevel(logging.DEBUG)

# Dynamic fields (ids, timestamps, usage maps) that differ between runs and are
# excluded from the payload comparison.
ignored_keys = [
    "request_id",
    "startTime",
    "endTime",
    "completionStartTime",
    "metadata.model_map_information",
    "metadata.usage_object",
]


def _compare_nested_dicts(
    actual: dict, expected: dict, path: str = "", ignore_keys: list[str] = []
) -> list[str]:
    """Compare nested dictionaries and return a list of differences in a human-friendly format."""
    differences = []

    # Check if the current path should be ignored
    if path in ignore_keys:
        return differences

    # Check for keys present in actual but not in expected
    for key in actual.keys():
        current_path = f"{path}.{key}" if path else key
        if current_path not in ignore_keys and key not in expected:
            differences.append(f"Extra key in actual: {current_path}")

    for key, expected_value in expected.items():
        current_path = f"{path}.{key}" if path else key
        if current_path in ignore_keys:
            continue
        if key not in actual:
            differences.append(f"Missing key: {current_path}")
            continue

        actual_value = actual[key]

        # Try to parse JSON strings so stringified payloads compare structurally
        if isinstance(expected_value, str):
            try:
                expected_value = json.loads(expected_value)
            except json.JSONDecodeError:
                pass
        if isinstance(actual_value, str):
            try:
                actual_value = json.loads(actual_value)
            except json.JSONDecodeError:
                pass

        if isinstance(expected_value, dict) and isinstance(actual_value, dict):
            differences.extend(
                _compare_nested_dicts(
                    actual_value, expected_value, current_path, ignore_keys
                )
            )
        elif isinstance(expected_value, dict) or isinstance(actual_value, dict):
            differences.append(
                f"Type mismatch at {current_path}: expected dict, got {type(actual_value).__name__}"
            )
        else:
            # For non-dict values, only report if they differ
            if actual_value != expected_value:
                # Truncate long values to keep the report readable
                actual_str = str(actual_value)
                expected_str = str(expected_value)
                if len(actual_str) > 50 or len(expected_str) > 50:
                    actual_str = f"{actual_str[:50]}..."
                    expected_str = f"{expected_str[:50]}..."
                differences.append(
                    f"Value mismatch at {current_path}:\n  expected: {expected_str}\n  got: {actual_str}"
                )

    return differences


def assert_gcs_pubsub_request_matches_expected(
    actual_request_body: dict,
    expected_file_name: str,
):
    """
    Helper function to compare the actual GCS Pub/Sub request body with an expected JSON file.

    Args:
        actual_request_body (dict): The actual request body received from the API call
        expected_file_name (str): Name of the JSON file containing the expected request body
    """
    # Get the current directory and read the expected request body
    pwd = os.path.dirname(os.path.realpath(__file__))
    expected_body_path = os.path.join(pwd, "gcs_pub_sub_body", expected_file_name)
    with open(expected_body_path, "r") as f:
        expected_request_body = json.load(f)

    # Compare the bodies, ignoring dynamic fields such as ids and timestamps
    differences = _compare_nested_dicts(
        actual_request_body, expected_request_body, ignore_keys=ignored_keys
    )
    if differences:
        assert False, f"Dictionary mismatch: {differences}"


def assert_gcs_pubsub_request_matches_expected_standard_logging_payload(
    actual_request_body: dict,
    expected_file_name: str,
):
    """
    Helper function to compare the actual GCS Pub/Sub request body with an expected JSON file.

    Args:
        actual_request_body (dict): The actual request body received from the API call
        expected_file_name (str): Name of the JSON file containing the expected request body
    """
    # Get the current directory and read the expected request body
    pwd = os.path.dirname(os.path.realpath(__file__))
    expected_body_path = os.path.join(pwd, "gcs_pub_sub_body", expected_file_name)
    with open(expected_body_path, "r") as f:
        expected_request_body = json.load(f)

    # Fields that must be present in the logged standard logging payload
    FIELDS_TO_VALIDATE = [
        "custom_llm_provider",
        "hidden_params",
        "messages",
        "response",
        "model",
        "status",
        "stream",
    ]

    # Replace dynamic values in the actual request body before comparing
    actual_request_body["response"]["id"] = expected_request_body["response"]["id"]
    actual_request_body["response"]["created"] = expected_request_body["response"][
        "created"
    ]

    for field in FIELDS_TO_VALIDATE:
        assert field in actual_request_body

    # Fields that must exist, but whose values vary from run to run
    FIELDS_EXISTENCE_CHECKS = [
        "response_cost",
        "response_time",
        "completion_tokens",
        "prompt_tokens",
        "total_tokens",
    ]

    for field in FIELDS_EXISTENCE_CHECKS:
        assert field in actual_request_body


@pytest.mark.asyncio
async def test_async_gcs_pub_sub():
    # Create a mock for the async_httpx_client's post method
    mock_post = AsyncMock()
    mock_post.return_value.status_code = 202
    mock_post.return_value.text = "Accepted"

    # Initialize the GcsPubSubLogger and attach the mocks
    gcs_pub_sub_logger = GcsPubSubLogger(flush_interval=1)
    gcs_pub_sub_logger.async_httpx_client.post = mock_post

    mock_construct_request_headers = AsyncMock()
    mock_construct_request_headers.return_value = {"Authorization": "Bearer mock_token"}
    gcs_pub_sub_logger.construct_request_headers = mock_construct_request_headers
    litellm.callbacks = [gcs_pub_sub_logger]

    # Make the completion call
    response = await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, world!"}],
        mock_response="hi",
    )

    await asyncio.sleep(3)  # Wait for the async flush

    # Assert httpx post was called exactly once
    mock_post.assert_called_once()

    # Get the actual request URL and body from the mock
    actual_url = mock_post.call_args[1]["url"]
    print("sent to url", actual_url)
    assert (
        actual_url
        == "https://pubsub.googleapis.com/v1/projects/reliableKeys/topics/litellmDB:publish"
    )
    actual_request = mock_post.call_args[1]["json"]

    # Extract and decode the base64-encoded Pub/Sub message
    encoded_message = actual_request["messages"][0]["data"]
    decoded_message = base64.b64decode(encoded_message).decode("utf-8")

    # Parse the JSON string into a dictionary
    actual_request = json.loads(decoded_message)

    print("##########\n")
    print(json.dumps(actual_request, indent=4))
    print("##########\n")

    # Verify the request body matches the expected standard logging payload
    assert_gcs_pubsub_request_matches_expected_standard_logging_payload(
        actual_request, "standard_logging_payload.json"
    )


@pytest.mark.asyncio
async def test_async_gcs_pub_sub_v1():
    # Use the v1 (spend logs) payload format for this test
    litellm.gcs_pub_sub_use_v1 = True

    # Create a mock for the async_httpx_client's post method
    mock_post = AsyncMock()
    mock_post.return_value.status_code = 202
    mock_post.return_value.text = "Accepted"

    # Initialize the GcsPubSubLogger and attach the mocks
    gcs_pub_sub_logger = GcsPubSubLogger(flush_interval=1)
    gcs_pub_sub_logger.async_httpx_client.post = mock_post

    mock_construct_request_headers = AsyncMock()
    mock_construct_request_headers.return_value = {"Authorization": "Bearer mock_token"}
    gcs_pub_sub_logger.construct_request_headers = mock_construct_request_headers
    litellm.callbacks = [gcs_pub_sub_logger]

    # Make the completion call
    response = await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, world!"}],
        mock_response="hi",
    )

    await asyncio.sleep(3)  # Wait for the async flush

    # Assert httpx post was called exactly once
    mock_post.assert_called_once()

    # Get the actual request URL and body from the mock
    actual_url = mock_post.call_args[1]["url"]
    print("sent to url", actual_url)
    assert (
        actual_url
        == "https://pubsub.googleapis.com/v1/projects/reliableKeys/topics/litellmDB:publish"
    )
    actual_request = mock_post.call_args[1]["json"]

    # Extract and decode the base64-encoded Pub/Sub message
    encoded_message = actual_request["messages"][0]["data"]
    decoded_message = base64.b64decode(encoded_message).decode("utf-8")

    # Parse the JSON string into a dictionary
    actual_request = json.loads(decoded_message)

    print("##########\n")
    print(json.dumps(actual_request, indent=4))
    print("##########\n")

    # Verify the request body matches the expected spend logs payload
    assert_gcs_pubsub_request_matches_expected(
        actual_request, "spend_logs_payload.json"
    )
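

# A minimal, self-contained sanity check for the comparison helper defined above.
# The sample dictionaries below are illustrative only and are not tied to any
# recorded GCS Pub/Sub payload.
def test_compare_nested_dicts_reports_differences():
    actual = {
        "model": "gpt-4o",
        "request_id": "abc-123",  # dynamic field, ignored below
        "metadata": {"user": "alice"},
    }
    expected = {
        "model": "gpt-4o-mini",
        "request_id": "xyz-789",
        "metadata": {"user": "alice", "team": "core"},
    }

    differences = _compare_nested_dicts(actual, expected, ignore_keys=["request_id"])

    # "model" differs, "metadata.team" is missing, and "request_id" is ignored
    assert any("model" in diff for diff in differences)
    assert any("metadata.team" in diff for diff in differences)
    assert not any("request_id" in diff for diff in differences)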