import asyncio
import json
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from urllib.parse import quote

from litellm._logging import verbose_logger
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.proxy._types import CommonProxyErrors
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import StandardLoggingPayload

if TYPE_CHECKING:
    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
else:
    VertexBase = Any


class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
    def __init__(self, bucket_name: Optional[str] = None) -> None:
        from litellm.proxy.proxy_server import premium_user

        super().__init__(bucket_name=bucket_name)

        # Init Batch logging settings
        self.log_queue: List[GCSLogQueueItem] = []
        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
        self.flush_interval = int(
            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
        )
        asyncio.create_task(self.periodic_flush())
        self.flush_lock = asyncio.Lock()
        super().__init__(
            flush_lock=self.flush_lock,
            batch_size=self.batch_size,
            flush_interval=self.flush_interval,
        )
        AdditionalLoggingUtils.__init__(self)

        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )

    #### ASYNC ####
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        from litellm.proxy.proxy_server import premium_user

        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if logging_payload is None:
                raise ValueError("standard_logging_object not found in kwargs")

            # Add to logging queue - this will be flushed periodically
            self.log_queue.append(
                GCSLogQueueItem(
                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
                )
            )
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if logging_payload is None:
                raise ValueError("standard_logging_object not found in kwargs")

            # Add to logging queue - this will be flushed periodically
            self.log_queue.append(
                GCSLogQueueItem(
                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
                )
            )
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    async def async_send_batch(self):
        """
        Process queued logs in batch - sends logs to GCS Bucket

        GCS Bucket does not have a Batch endpoint to batch upload logs

        Instead, we
        - collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
        - during async_send_batch, we make 1 POST request per log to GCS Bucket
        """
        if not self.log_queue:
            return

        for log_item in self.log_queue:
            logging_payload = log_item["payload"]
            kwargs = log_item["kwargs"]
            response_obj = log_item.get("response_obj", None) or {}

            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            object_name = self._get_object_name(kwargs, logging_payload, response_obj)

            try:
                await self._log_json_data_on_gcs(
                    headers=headers,
                    bucket_name=bucket_name,
                    object_name=object_name,
                    logging_payload=logging_payload,
                )
            except Exception as e:
                # don't let one log item fail the entire batch
                verbose_logger.exception(
                    f"GCS Bucket error logging payload to GCS bucket: {str(e)}"
                )
                pass

        # Clear the queue after processing
        self.log_queue.clear()

    def _get_object_name(
        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
    ) -> str:
        """
        Get the object name to use for the current payload
        """
        current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
        if logging_payload.get("error_str", None) is not None:
            object_name = self._generate_failure_object_name(
                request_date_str=current_date,
            )
        else:
            object_name = self._generate_success_object_name(
                request_date_str=current_date,
                response_id=response_obj.get("id", ""),
            )

        # used for testing
        _litellm_params = kwargs.get("litellm_params", None) or {}
        _metadata = _litellm_params.get("metadata", None) or {}
        if "gcs_log_id" in _metadata:
            object_name = _metadata["gcs_log_id"]

        return object_name

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """
        Get the request and response payload for a given `request_id`
        Tries current day, next day, and previous day until it finds the payload
        """
        if start_time_utc is None:
            raise ValueError(
                "start_time_utc is required for getting a payload from GCS Bucket"
            )

        # Try current day, next day, and previous day
        dates_to_try = [
            start_time_utc,
            start_time_utc + timedelta(days=1),
            start_time_utc - timedelta(days=1),
        ]
        date_str = None
        for date in dates_to_try:
            try:
                date_str = self._get_object_date_from_datetime(datetime_obj=date)
                object_name = self._generate_success_object_name(
                    request_date_str=date_str,
                    response_id=request_id,
                )
                encoded_object_name = quote(object_name, safe="")
                response = await self.download_gcs_object(encoded_object_name)

                if response is not None:
                    loaded_response = json.loads(response)
                    return loaded_response
            except Exception as e:
                verbose_logger.debug(
                    f"Failed to fetch payload for date {date_str}: {str(e)}"
                )
                continue

        return None

    def _generate_success_object_name(
        self,
        request_date_str: str,
        response_id: str,
    ) -> str:
        return f"{request_date_str}/{response_id}"

    def _generate_failure_object_name(
        self,
        request_date_str: str,
    ) -> str:
        return f"{request_date_str}/failure-{uuid.uuid4().hex}"

    def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
        return datetime_obj.strftime("%Y-%m-%d")

    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        raise NotImplementedError("GCS Bucket does not support health check")
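

# Usage sketch (comments only, illustrative; names of env vars and the
# callbacks hook are assumptions about the surrounding litellm setup, not
# defined in this module). Shows how this logger is typically wired in,
# assuming a premium deployment (the constructor raises otherwise), a running
# event loop (asyncio.create_task in __init__), and GCS credentials/bucket
# configured for GCSBucketBase, e.g. via GCS_BUCKET_NAME and
# GCS_PATH_SERVICE_ACCOUNT:
#
#   import litellm
#
#   gcs_logger = GCSBucketLogger(bucket_name="my-gcs-bucket")  # placeholder bucket
#   litellm.callbacks = [gcs_logger]
#
# Successful and failed requests are then queued by async_log_success_event /
# async_log_failure_event and flushed by async_send_batch, one GCS object per
# log, every GCS_FLUSH_INTERVAL seconds or once GCS_BATCH_SIZE items accumulate.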