litellm-mirror/litellm/integrations/gcs_bucket/gcs_bucket.py
Krish Dholakia e1f7bcb47d
Fix VertexAI Credential Caching issue (#9756)
* refactor(vertex_llm_base.py): Prevent credential misrouting for projects

Fixes https://github.com/BerriAI/litellm/issues/7904

* fix: passing unit tests

* fix(vertex_llm_base.py): common auth logic across sync + async vertex ai calls

prevents credential caching issue across both flows

* test: fix test

* fix(vertex_llm_base.py): handle project id in default case

* fix(factory.py): don't pass cache control if not set

bedrock invoke does not support this

* test: fix test

* fix(vertex_llm_base.py): add .exception message in load_auth

* fix: fix ruff error
2025-04-04 16:38:08 -07:00

234 lines
8.6 KiB
Python

import asyncio
import json
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from urllib.parse import quote
from litellm._logging import verbose_logger
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.proxy._types import CommonProxyErrors
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
else:
VertexBase = Any
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
    """
    Batched logger that writes litellm StandardLoggingPayloads to a GCS bucket.

    Success and failure events are appended to `self.log_queue`; the queue is
    drained by `async_send_batch`, which uploads one object per queued payload.
    Using this logger requires `premium_user` to be True.
    """

    def __init__(self, bucket_name: Optional[str] = None) -> None:
        """
        Args:
            bucket_name: bucket to write logs to; forwarded to GCSBucketBase.

        Raises:
            ValueError: if the proxy is not running as a premium user.
        """
        # Fail fast, before any background task is scheduled. Previously the
        # periodic flush task was started first, so a non-premium instantiation
        # raised but left that task running against a half-initialised logger.
        self._ensure_premium_user()

        # Init Batch logging settings
        self.log_queue: List[GCSLogQueueItem] = []
        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
        self.flush_interval = int(
            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
        )
        self.flush_lock = asyncio.Lock()

        # Initialise the base classes exactly once. The original code called
        # super().__init__ a second time without `bucket_name`, re-running
        # GCSBucketBase initialisation and presumably discarding the bucket
        # passed in here. NOTE(review): relies on GCSBucketBase.__init__
        # accepting bucket_name together with the batch kwargs (it was already
        # called with each set separately) — confirm.
        super().__init__(
            bucket_name=bucket_name,
            flush_lock=self.flush_lock,
            batch_size=self.batch_size,
            flush_interval=self.flush_interval,
        )
        AdditionalLoggingUtils.__init__(self)

        # Start the periodic flush only once every attribute it may use is set.
        # Keep a reference to the task: the asyncio docs warn that the event
        # loop holds only weak references, so an unreferenced task can vanish.
        self._periodic_flush_task = asyncio.create_task(self.periodic_flush())

    @staticmethod
    def _ensure_premium_user() -> None:
        """Raise ValueError unless `litellm.proxy.proxy_server.premium_user` is True."""
        # Imported at call time (as before), so the current value of the flag
        # is read on every check.
        from litellm.proxy.proxy_server import premium_user

        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )

    #### ASYNC ####

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Queue the standard logging payload of a successful call for upload.

        Raises:
            ValueError: if the proxy is not running as a premium user. Any
            other error while queueing is logged and swallowed.
        """
        self._ensure_premium_user()
        self._enqueue_logging_payload("async_log_success_event", kwargs, response_obj)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Queue the standard logging payload of a failed call for upload.

        Errors while queueing are logged and swallowed, not raised.
        """
        self._enqueue_logging_payload("async_log_failure_event", kwargs, response_obj)

    def _enqueue_logging_payload(self, event_name: str, kwargs, response_obj) -> None:
        """
        Append `kwargs["standard_logging_object"]` to `self.log_queue`.

        Best-effort: a missing payload (or any other error) is logged via
        `verbose_logger.exception` and does not propagate to the caller.
        """
        try:
            verbose_logger.debug(
                "GCS Logger: %s logging kwargs: %s, response_obj: %s",
                event_name,
                kwargs,
                response_obj,
            )
            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if logging_payload is None:
                raise ValueError("standard_logging_object not found in kwargs")
            # Add to logging queue - this will be flushed periodically
            self.log_queue.append(
                GCSLogQueueItem(
                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
                )
            )
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    async def async_send_batch(self):
        """
        Process queued logs in batch - sends logs to GCS Bucket

        GCS Bucket does not have a Batch endpoint to batch upload logs
        Instead, we
        - collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
        - during async_send_batch, we make 1 POST request per log to GCS Bucket

        Each item is handled independently: a failure while resolving its
        config, building auth headers, naming the object or uploading it is
        logged and the remaining items are still sent. The queue is cleared
        once every item has been attempted.
        """
        if not self.log_queue:
            return

        for log_item in self.log_queue:
            # All per-item work lives inside the try. Previously config/header
            # resolution sat outside it, so one failing item aborted the batch
            # before the queue was cleared: already-sent items were re-sent and
            # the bad item was retried on every flush.
            try:
                logging_payload = log_item["payload"]
                kwargs = log_item["kwargs"]
                response_obj = log_item.get("response_obj", None) or {}

                gcs_logging_config: GCSLoggingConfig = (
                    await self.get_gcs_logging_config(kwargs)
                )
                headers = await self.construct_request_headers(
                    vertex_instance=gcs_logging_config["vertex_instance"],
                    service_account_json=gcs_logging_config["path_service_account"],
                )
                bucket_name = gcs_logging_config["bucket_name"]
                object_name = self._get_object_name(
                    kwargs, logging_payload, response_obj
                )

                await self._log_json_data_on_gcs(
                    headers=headers,
                    bucket_name=bucket_name,
                    object_name=object_name,
                    logging_payload=logging_payload,
                )
            except Exception as e:
                # don't let one log item fail the entire batch
                verbose_logger.exception(
                    f"GCS Bucket error logging payload to GCS bucket: {str(e)}"
                )

        # Clear the queue after processing
        self.log_queue.clear()

    def _get_object_name(
        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
    ) -> str:
        """
        Get the object name to use for the current payload.

        Names are prefixed with the current UTC date (YYYY-MM-DD). Payloads with
        a non-None `error_str` get a unique `failure-<uuid>` name; others are
        named after `response_obj["id"]`. A `gcs_log_id` in
        `kwargs["litellm_params"]["metadata"]` overrides either (used for testing).
        """
        current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
        if logging_payload.get("error_str", None) is not None:
            object_name = self._generate_failure_object_name(
                request_date_str=current_date,
            )
        else:
            object_name = self._generate_success_object_name(
                request_date_str=current_date,
                response_id=response_obj.get("id", ""),
            )

        # used for testing
        _litellm_params = kwargs.get("litellm_params", None) or {}
        _metadata = _litellm_params.get("metadata", None) or {}
        if "gcs_log_id" in _metadata:
            object_name = _metadata["gcs_log_id"]

        return object_name

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """
        Get the request and response payload for a given `request_id`.

        Objects are keyed by upload date, which may differ from the request's
        start date, so this tries the start day, the next day and the previous
        day, returning the first object found parsed as JSON.

        Returns:
            The parsed payload, or None if no candidate object could be fetched.

        Raises:
            ValueError: if `start_time_utc` is None.
        """
        if start_time_utc is None:
            raise ValueError(
                "start_time_utc is required for getting a payload from GCS Bucket"
            )

        # Try current day, next day, and previous day
        dates_to_try = [
            start_time_utc,
            start_time_utc + timedelta(days=1),
            start_time_utc - timedelta(days=1),
        ]
        date_str = None
        for date in dates_to_try:
            try:
                date_str = self._get_object_date_from_datetime(datetime_obj=date)
                object_name = self._generate_success_object_name(
                    request_date_str=date_str,
                    response_id=request_id,
                )
                encoded_object_name = quote(object_name, safe="")
                response = await self.download_gcs_object(encoded_object_name)

                if response is not None:
                    loaded_response = json.loads(response)
                    return loaded_response
            except Exception as e:
                # Best-effort lookup: a miss or error on one date moves on to the next.
                verbose_logger.debug(
                    f"Failed to fetch payload for date {date_str}: {str(e)}"
                )
                continue

        return None

    def _generate_success_object_name(
        self,
        request_date_str: str,
        response_id: str,
    ) -> str:
        """Return `<date>/<response_id>`."""
        return f"{request_date_str}/{response_id}"

    def _generate_failure_object_name(
        self,
        request_date_str: str,
    ) -> str:
        """Return `<date>/failure-<random uuid4 hex>` so failures never collide."""
        return f"{request_date_str}/failure-{uuid.uuid4().hex}"

    def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
        """Format `datetime_obj` as YYYY-MM-DD for use as an object-name prefix."""
        return datetime_obj.strftime("%Y-%m-%d")

    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        """Not supported for GCS Bucket; always raises NotImplementedError."""
        raise NotImplementedError("GCS Bucket does not support health check")