Merge pull request #9508 from BerriAI/litellm_fix_gcs_pub_sub

[Fix] Use StandardLoggingPayload for GCS Pub Sub Logging Integration
This commit is contained in:
Ishaan Jaff 2025-03-24 18:22:43 -07:00 committed by GitHub
commit d17ab7da2a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 312 additions and 12 deletions

View file

@ -122,6 +122,9 @@ langsmith_batch_size: Optional[int] = None
prometheus_initialize_budget_metrics: Optional[bool] = False
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
gcs_pub_sub_use_v1: Optional[bool] = (
False # if you want to use v1 gcs pubsub logged payload
)
argilla_transformation_object: Optional[Dict[str, Any]] = None
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
[]

View file

@ -10,13 +10,16 @@ import asyncio
import json
import os
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from litellm.proxy._types import SpendLogsPayload
else:
SpendLogsPayload = Any
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
@ -61,7 +64,7 @@ class GcsPubSubLogger(CustomBatchLogger):
self.flush_lock = asyncio.Lock()
super().__init__(**kwargs, flush_lock=self.flush_lock)
asyncio.create_task(self.periodic_flush())
self.log_queue: List[SpendLogsPayload] = []
self.log_queue: List[Union[SpendLogsPayload, StandardLoggingPayload]] = []
async def construct_request_headers(self) -> Dict[str, str]:
"""Construct authorization headers using Vertex AI auth"""
@ -115,13 +118,20 @@ class GcsPubSubLogger(CustomBatchLogger):
verbose_logger.debug(
"PubSub: Logging - Enters logging function for model %s", kwargs
)
spend_logs_payload = get_logging_payload(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
)
self.log_queue.append(spend_logs_payload)
standard_logging_payload = kwargs.get("standard_logging_object", None)
# Backwards compatibility with old logging payload
if litellm.gcs_pub_sub_use_v1 is True:
spend_logs_payload = get_logging_payload(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
)
self.log_queue.append(spend_logs_payload)
else:
# New logging payload, StandardLoggingPayload
self.log_queue.append(standard_logging_payload)
if len(self.log_queue) >= self.batch_size:
await self.async_send_batch()
@ -155,7 +165,7 @@ class GcsPubSubLogger(CustomBatchLogger):
self.log_queue.clear()
async def publish_message(
self, message: SpendLogsPayload
self, message: Union[SpendLogsPayload, StandardLoggingPayload]
) -> Optional[Dict[str, Any]]:
"""
Publish message to Google Cloud Pub/Sub using REST API

View file

@ -0,0 +1,175 @@
{
"id": "chatcmpl-2299b6a2-82a3-465a-b47c-04e685a2227f",
"trace_id": null,
"call_type": "acompletion",
"cache_hit": null,
"stream": true,
"status": "success",
"custom_llm_provider": "openai",
"saved_cache_cost": 0.0,
"startTime": "2025-01-24 09:20:46.847371",
"endTime": "2025-01-24 09:20:46.851954",
"completionStartTime": "2025-01-24 09:20:46.851954",
"response_time": 0.007394075393676758,
"model": "gpt-4o",
"metadata": {
"user_api_key_hash": null,
"user_api_key_alias": null,
"user_api_key_team_id": null,
"user_api_key_org_id": null,
"user_api_key_user_id": null,
"user_api_key_team_alias": null,
"user_api_key_user_email": null,
"spend_logs_metadata": null,
"requester_ip_address": null,
"requester_metadata": null,
"user_api_key_end_user_id": null,
"prompt_management_metadata": null,
"applied_guardrails": []
},
"cache_key": null,
"response_cost": 0.00022500000000000002,
"total_tokens": 30,
"prompt_tokens": 10,
"completion_tokens": 20,
"request_tags": [],
"end_user": "",
"api_base": "",
"model_group": "",
"model_id": "",
"requester_ip_address": null,
"messages": [
{
"role": "user",
"content": "Hello, world!"
}
],
"response": {
"id": "chatcmpl-2299b6a2-82a3-465a-b47c-04e685a2227f",
"created": 1742855151,
"model": "gpt-4o",
"object": "chat.completion",
"system_fingerprint": null,
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "hi",
"role": "assistant",
"tool_calls": null,
"function_call": null
}
}
],
"usage": {
"completion_tokens": 20,
"prompt_tokens": 10,
"total_tokens": 30,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
},
"model_parameters": {},
"hidden_params": {
"model_id": null,
"cache_key": null,
"api_base": "https://api.openai.com",
"response_cost": 0.00022500000000000002,
"additional_headers": {},
"litellm_overhead_time_ms": null,
"batch_models": null,
"litellm_model_name": "gpt-4o"
},
"model_map_information": {
"model_map_key": "gpt-4o",
"model_map_value": {
"key": "gpt-4o",
"max_tokens": 16384,
"max_input_tokens": 128000,
"max_output_tokens": 16384,
"input_cost_per_token": 2.5e-06,
"cache_creation_input_token_cost": null,
"cache_read_input_token_cost": 1.25e-06,
"input_cost_per_character": null,
"input_cost_per_token_above_128k_tokens": null,
"input_cost_per_query": null,
"input_cost_per_second": null,
"input_cost_per_audio_token": null,
"input_cost_per_token_batches": 1.25e-06,
"output_cost_per_token_batches": 5e-06,
"output_cost_per_token": 1e-05,
"output_cost_per_audio_token": null,
"output_cost_per_character": null,
"output_cost_per_token_above_128k_tokens": null,
"output_cost_per_character_above_128k_tokens": null,
"output_cost_per_second": null,
"output_cost_per_image": null,
"output_vector_size": null,
"litellm_provider": "openai",
"mode": "chat",
"supports_system_messages": true,
"supports_response_schema": true,
"supports_vision": true,
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_assistant_prefill": false,
"supports_prompt_caching": true,
"supports_audio_input": false,
"supports_audio_output": false,
"supports_pdf_input": false,
"supports_embedding_image_input": false,
"supports_native_streaming": null,
"supports_web_search": true,
"search_context_cost_per_query": {
"search_context_size_low": 0.03,
"search_context_size_medium": 0.035,
"search_context_size_high": 0.05
},
"tpm": null,
"rpm": null,
"supported_openai_params": [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"max_completion_tokens",
"modalities",
"prediction",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"function_call",
"functions",
"max_retries",
"extra_headers",
"parallel_tool_calls",
"audio",
"response_format",
"user"
]
}
},
"error_str": null,
"error_information": {
"error_code": "",
"error_class": "",
"llm_provider": "",
"traceback": "",
"error_message": ""
},
"response_cost_failure_debug_info": null,
"guardrail_information": null,
"standard_built_in_tools_params": {
"web_search_options": null,
"file_search": null
}
}

View file

@ -6,6 +6,7 @@ import sys
sys.path.insert(0, os.path.abspath("../.."))
import asyncio
import litellm
import gzip
import json
import logging
@ -48,8 +49,15 @@ def assert_gcs_pubsub_request_matches_expected(
expected_request_body = json.load(f)
# Replace dynamic values in actual request body
time_fields = ["startTime", "endTime", "completionStartTime", "request_id"]
for field in time_fields:
dynamic_fields = [
"startTime",
"endTime",
"completionStartTime",
"request_id",
"id",
"response_time",
]
for field in dynamic_fields:
if field in actual_request_body:
actual_request_body[field] = expected_request_body[field]
@ -59,6 +67,55 @@ def assert_gcs_pubsub_request_matches_expected(
), f"Difference in request bodies: {json.dumps(actual_request_body, indent=2)} != {json.dumps(expected_request_body, indent=2)}"
def assert_gcs_pubsub_request_matches_expected_standard_logging_payload(
    actual_request_body: dict,
    expected_file_name: str,
):
    """
    Check that a GCS PubSub request body has the expected StandardLoggingPayload shape.

    Loads the expected payload from ``gcs_pub_sub_body/<expected_file_name>``,
    copies the per-call dynamic ``response.id`` / ``response.created`` values
    from the expected payload into the actual body, then asserts that the key
    StandardLoggingPayload fields are present in the actual request body.

    Args:
        actual_request_body (dict): The actual request body received from the API call
        expected_file_name (str): Name of the JSON file containing expected request body
    """
    expected_body_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "gcs_pub_sub_body",
        expected_file_name,
    )
    with open(expected_body_path, "r") as fp:
        expected_request_body = json.load(fp)

    # response.id / response.created are generated fresh on every call;
    # normalize them to the expected values before validating structure.
    for dynamic_key in ("id", "created"):
        actual_request_body["response"][dynamic_key] = expected_request_body[
            "response"
        ][dynamic_key]

    # Fields whose presence (and, implicitly, shape) we validate. The second
    # group holds dynamic values, so only key existence is checked.
    required_fields = (
        "custom_llm_provider",
        "hidden_params",
        "messages",
        "response",
        "model",
        "status",
        "stream",
        "response_cost",
        "response_time",
        "completion_tokens",
        "prompt_tokens",
        "total_tokens",
    )
    for field in required_fields:
        assert field in actual_request_body
@pytest.mark.asyncio
async def test_async_gcs_pub_sub():
# Create a mock for the async_httpx_client's post method
@ -102,6 +159,61 @@ async def test_async_gcs_pub_sub():
decoded_message = base64.b64decode(encoded_message).decode("utf-8")
# Parse the JSON string into a dictionary
actual_request = json.loads(decoded_message)
print("##########\n")
print(json.dumps(actual_request, indent=4))
print("##########\n")
# Verify the request body matches expected format
assert_gcs_pubsub_request_matches_expected_standard_logging_payload(
actual_request, "standard_logging_payload.json"
)
@pytest.mark.asyncio
async def test_async_gcs_pub_sub_v1():
# Create a mock for the async_httpx_client's post method
litellm.gcs_pub_sub_use_v1 = True
mock_post = AsyncMock()
mock_post.return_value.status_code = 202
mock_post.return_value.text = "Accepted"
# Initialize the GcsPubSubLogger and set the mock
gcs_pub_sub_logger = GcsPubSubLogger(flush_interval=1)
gcs_pub_sub_logger.async_httpx_client.post = mock_post
mock_construct_request_headers = AsyncMock()
mock_construct_request_headers.return_value = {"Authorization": "Bearer mock_token"}
gcs_pub_sub_logger.construct_request_headers = mock_construct_request_headers
litellm.callbacks = [gcs_pub_sub_logger]
# Make the completion call
response = await litellm.acompletion(
model="gpt-4o",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="hi",
)
await asyncio.sleep(3) # Wait for async flush
# Assert httpx post was called
mock_post.assert_called_once()
# Get the actual request body from the mock
actual_url = mock_post.call_args[1]["url"]
print("sent to url", actual_url)
assert (
actual_url
== "https://pubsub.googleapis.com/v1/projects/reliableKeys/topics/litellmDB:publish"
)
actual_request = mock_post.call_args[1]["json"]
# Extract and decode the base64 encoded message
encoded_message = actual_request["messages"][0]["data"]
import base64
decoded_message = base64.b64decode(encoded_message).decode("utf-8")
# Parse the JSON string into a dictionary
actual_request = json.loads(decoded_message)
print("##########\n")