"""
BETA

This is the PubSub logger for GCS PubSub. It sends LiteLLM SpendLogs payloads
to GCS PubSub.

Users can use this instead of sending their SpendLogs to their Postgres database.
"""
import asyncio
import base64
import json
import os
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload

if TYPE_CHECKING:
    from litellm.proxy._types import SpendLogsPayload
else:
    SpendLogsPayload = Any
class GcsPubSubLogger(CustomBatchLogger):
    """
    Batched logger that publishes LiteLLM spend/logging payloads to a
    Google Cloud Pub/Sub topic via the REST API.

    Payloads are queued in memory and flushed either when the batch size is
    reached or by the periodic flush task inherited from CustomBatchLogger.
    """

    def __init__(
        self,
        project_id: Optional[str] = None,
        topic_id: Optional[str] = None,
        credentials_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize Google Cloud Pub/Sub publisher.

        Args:
            project_id (str, optional): Google Cloud project ID. Falls back to
                the GCS_PUBSUB_PROJECT_ID environment variable.
            topic_id (str, optional): Pub/Sub topic ID. Falls back to the
                GCS_PUBSUB_TOPIC_ID environment variable.
            credentials_path (str, optional): Path to a Google Cloud
                credentials JSON file. Falls back to GCS_PATH_SERVICE_ACCOUNT.
            **kwargs: Forwarded to CustomBatchLogger (batch size, flush
                interval, etc.).

        Raises:
            ValueError: If neither the arguments nor the environment variables
                supply both a project_id and a topic_id.
        """
        from litellm.proxy.utils import _premium_user_check

        # Pub/Sub logging is a premium feature; this raises for non-premium users.
        _premium_user_check()

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID")
        self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID")
        self.path_service_account_json = credentials_path or os.getenv(
            "GCS_PATH_SERVICE_ACCOUNT"
        )

        if not self.project_id or not self.topic_id:
            raise ValueError("Both project_id and topic_id must be provided")

        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        # Ensure the queue exists before the background flush task is
        # scheduled, so periodic_flush can never observe a missing attribute.
        self.log_queue: List[Union[SpendLogsPayload, StandardLoggingPayload]] = []
        # NOTE(review): create_task requires a running event loop — assumes
        # this logger is constructed from async proxy startup; confirm callers.
        asyncio.create_task(self.periodic_flush())

    async def construct_request_headers(self) -> Dict[str, str]:
        """
        Construct authorization headers using Vertex AI auth.

        Returns:
            Dict[str, str]: Headers carrying a Bearer token and JSON
            content type, suitable for the Pub/Sub REST endpoint.
        """
        from litellm import vertex_chat_completion

        _auth_header, vertex_project = (
            await vertex_chat_completion._ensure_access_token_async(
                credentials=self.path_service_account_json,
                project_id=None,
                custom_llm_provider="vertex_ai",
            )
        )

        auth_header, _ = vertex_chat_completion._get_token_and_url(
            model="pub-sub",
            auth_header=_auth_header,
            vertex_credentials=self.path_service_account_json,
            vertex_project=vertex_project,
            vertex_location=None,
            gemini_api_key=None,
            stream=None,
            custom_llm_provider="vertex_ai",
            api_base=None,
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
            "Content-Type": "application/json",
        }
        return headers

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async log success events to the GCS Pub/Sub topic.

        - Creates a SpendLogsPayload (v1) or uses the StandardLoggingPayload
        - Adds it to the batch queue
        - Flushes based on CustomBatchLogger settings

        Raises:
            Nothing to the caller — any error is logged as a NON-blocking
            verbose_logger.exception.
        """
        from litellm.proxy.spend_tracking.spend_tracking_utils import (
            get_logging_payload,
        )
        from litellm.proxy.utils import _premium_user_check

        _premium_user_check()

        try:
            verbose_logger.debug(
                "PubSub: Logging - Enters logging function for model %s", kwargs
            )
            standard_logging_payload = kwargs.get("standard_logging_object", None)

            # Backwards compatibility with old logging payload
            if litellm.gcs_pub_sub_use_v1 is True:
                spend_logs_payload = get_logging_payload(
                    kwargs=kwargs,
                    response_obj=response_obj,
                    start_time=start_time,
                    end_time=end_time,
                )
                self.log_queue.append(spend_logs_payload)
            else:
                # New logging payload, StandardLoggingPayload.
                # Guard against a missing payload: previously `None` was
                # enqueued and published to Pub/Sub as the JSON literal "null".
                if standard_logging_payload is None:
                    raise ValueError(
                        "standard_logging_object not found in kwargs"
                    )
                self.log_queue.append(standard_logging_payload)

            if len(self.log_queue) >= self.batch_size:
                await self.async_send_batch()

        except Exception as e:
            verbose_logger.exception(
                f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}"
            )

    async def async_send_batch(self):
        """
        Send the queued batch of messages to Pub/Sub.

        The queue is always cleared afterwards (even on failure), so a
        persistent Pub/Sub outage drops messages rather than growing the
        queue without bound.
        """
        try:
            if not self.log_queue:
                return

            verbose_logger.debug(
                f"PubSub - about to flush {len(self.log_queue)} events"
            )

            # Messages are published one at a time; each failure is logged
            # inside publish_message and does not abort the batch.
            for message in self.log_queue:
                await self.publish_message(message)

        except Exception as e:
            verbose_logger.exception(
                f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}"
            )
        finally:
            self.log_queue.clear()

    async def publish_message(
        self, message: Union[SpendLogsPayload, StandardLoggingPayload]
    ) -> Optional[Dict[str, Any]]:
        """
        Publish a message to Google Cloud Pub/Sub using the REST API.

        Args:
            message: Message to publish (dict-like payload or string).

        Returns:
            dict: Parsed JSON response from the publish endpoint, or
            None if publishing failed (errors are logged, not raised).
        """
        try:
            headers = await self.construct_request_headers()

            # Prepare message data: strings pass through; anything else is
            # JSON-serialized (default=str stringifies non-JSON types).
            if isinstance(message, str):
                message_data = message
            else:
                message_data = json.dumps(message, default=str)

            # Pub/Sub requires the message payload to be base64-encoded.
            encoded_message = base64.b64encode(message_data.encode("utf-8")).decode(
                "utf-8"
            )

            request_body = {"messages": [{"data": encoded_message}]}

            url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish"

            response = await self.async_httpx_client.post(
                url=url, headers=headers, json=request_body
            )

            if response.status_code not in [200, 202]:
                verbose_logger.error("Pub/Sub publish error: %s", str(response.text))
                raise Exception(f"Failed to publish message: {response.text}")

            verbose_logger.debug("Pub/Sub response: %s", response.text)
            return response.json()

        except Exception as e:
            # Best-effort logging: never let a publish failure propagate.
            verbose_logger.error("Pub/Sub publish error: %s", str(e))
            return None