diff --git a/.circleci/config.yml b/.circleci/config.yml
index d2d83cd0e..88e83fa7f 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1006,6 +1006,7 @@ jobs:
           -e AWS_REGION_NAME=$AWS_REGION_NAME \
           -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
           -e COHERE_API_KEY=$COHERE_API_KEY \
+          -e GCS_FLUSH_INTERVAL="1" \
           --name my-app \
           -v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
           -v $(pwd)/litellm/proxy/example_config_yaml/custom_guardrail.py:/app/custom_guardrail.py \
diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md
index 1adc4943d..b4d70a4e7 100644
--- a/docs/my-website/docs/proxy/configs.md
+++ b/docs/my-website/docs/proxy/configs.md
@@ -934,6 +934,8 @@ router_settings:
 | EMAIL_SUPPORT_CONTACT | Support contact email address
 | GCS_BUCKET_NAME | Name of the Google Cloud Storage bucket
 | GCS_PATH_SERVICE_ACCOUNT | Path to the Google Cloud service account JSON file
+| GCS_FLUSH_INTERVAL | Flush interval for GCS logging, in seconds. Controls how often queued logs are sent to GCS. Defaults to 20 seconds.
+| GCS_BATCH_SIZE | Batch size for GCS logging. Controls how many logs are queued before a flush; if `GCS_BATCH_SIZE` is set to 10, logs are flushed every 10 logs. Defaults to 2048.
 | GENERIC_AUTHORIZATION_ENDPOINT | Authorization endpoint for generic OAuth providers
 | GENERIC_CLIENT_ID | Client ID for generic OAuth providers
 | GENERIC_CLIENT_SECRET | Client secret for generic OAuth providers
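To make the two new knobs concrete, here is a minimal sketch of setting them before the process that constructs the GCS logger starts. The bucket name and path are placeholders, and the defaults in the comments come from the integration change further down.

```python
import os

# Placeholder values for illustration only.
os.environ["GCS_BUCKET_NAME"] = "my-litellm-logs"
os.environ["GCS_PATH_SERVICE_ACCOUNT"] = "/path/to/service_account.json"

# New batching knobs, read once when GCSBucketLogger is constructed.
os.environ["GCS_FLUSH_INTERVAL"] = "5"   # flush queued logs every 5 seconds (default: 20)
os.environ["GCS_BATCH_SIZE"] = "100"     # flush once 100 logs are queued (default: 2048)
```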
diff --git a/litellm/integrations/custom_batch_logger.py b/litellm/integrations/custom_batch_logger.py
index aa7f0bba2..7ef63d25c 100644
--- a/litellm/integrations/custom_batch_logger.py
+++ b/litellm/integrations/custom_batch_logger.py
@@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
         self,
         flush_lock: Optional[asyncio.Lock] = None,
         batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
+        flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
         **kwargs,
     ) -> None:
         """
@@ -28,7 +29,7 @@
             flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
         """
         self.log_queue: List = []
-        self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS  # 10 seconds
+        self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
         self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
         self.last_flush_time = time.time()
         self.flush_lock = flush_lock
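For context on what `flush_interval` now controls: `CustomBatchLogger` keeps a `log_queue` and runs a background `periodic_flush()` task that drains it via `async_send_batch()`. A simplified sketch of that loop, using only the attribute names visible in this diff (an illustration, not the actual implementation):

```python
import asyncio
import time


async def periodic_flush_sketch(logger) -> None:
    """Wake up every flush_interval seconds and drain the queue."""
    while True:
        await asyncio.sleep(logger.flush_interval)
        if not logger.log_queue:
            continue
        async with logger.flush_lock:        # one flush at a time
            await logger.async_send_batch()  # subclass-specific upload (GCS below)
            logger.last_flush_time = time.time()
```

With this parameter, a subclass such as the GCS logger can pass its own interval instead of being pinned to `DEFAULT_FLUSH_INTERVAL_SECONDS`.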
diff --git a/litellm/integrations/gcs_bucket/gcs_bucket.py b/litellm/integrations/gcs_bucket/gcs_bucket.py
index f7f36c124..0b637f9b6 100644
--- a/litellm/integrations/gcs_bucket/gcs_bucket.py
+++ b/litellm/integrations/gcs_bucket/gcs_bucket.py
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import os
 import uuid
@@ -10,10 +11,12 @@ from pydantic import BaseModel, Field

 import litellm
 from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload
+from litellm.types.integrations.gcs_bucket import *
 from litellm.types.utils import (
     StandardCallbackDynamicParams,
     StandardLoggingMetadata,
@@ -27,12 +30,8 @@ else:


 IAM_AUTH_KEY = "IAM_AUTH"
-
-
-class GCSLoggingConfig(TypedDict):
-    bucket_name: str
-    vertex_instance: VertexBase
-    path_service_account: Optional[str]
+GCS_DEFAULT_BATCH_SIZE = 2048
+GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20


 class GCSBucketLogger(GCSBucketBase):
@@ -41,6 +40,21 @@ class GCSBucketLogger(GCSBucketBase):
         super().__init__(bucket_name=bucket_name)
         self.vertex_instances: Dict[str, VertexBase] = {}
+
+        # Init Batch logging settings
+        self.log_queue: List[GCSLogQueueItem] = []
+        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
+        self.flush_interval = int(
+            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
+        )
+        asyncio.create_task(self.periodic_flush())
+        self.flush_lock = asyncio.Lock()
+        super().__init__(
+            flush_lock=self.flush_lock,
+            batch_size=self.batch_size,
+            flush_interval=self.flush_interval,
+        )
+
         if premium_user is not True:
             raise ValueError(
                 f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
             )
@@ -60,44 +74,23 @@ class GCSBucketLogger(GCSBucketBase):
                 kwargs,
                 response_obj,
             )
-            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
-                kwargs
-            )
-            headers = await self.construct_request_headers(
-                vertex_instance=gcs_logging_config["vertex_instance"],
-                service_account_json=gcs_logging_config["path_service_account"],
-            )
-            bucket_name = gcs_logging_config["bucket_name"]
-
             logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                 "standard_logging_object", None
             )
-
             if logging_payload is None:
                 raise ValueError("standard_logging_object not found in kwargs")

-            # Get the current date
-            current_date = datetime.now().strftime("%Y-%m-%d")
-
-            # Modify the object_name to include the date-based folder
-            object_name = f"{current_date}/{response_obj['id']}"
-
-            await self._log_json_data_on_gcs(
-                headers=headers,
-                bucket_name=bucket_name,
-                object_name=object_name,
-                logging_payload=logging_payload,
+            # Add to logging queue - this will be flushed periodically
+            self.log_queue.append(
+                GCSLogQueueItem(
+                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+                )
             )
+
         except Exception as e:
             verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
-        from litellm.proxy.proxy_server import premium_user
-
-        if premium_user is not True:
-            raise ValueError(
-                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
-            )
         try:
             verbose_logger.debug(
                 "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
@@ -105,44 +98,77 @@
                 kwargs,
                 response_obj,
             )
-            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
-                kwargs
-            )
-            headers = await self.construct_request_headers(
-                vertex_instance=gcs_logging_config["vertex_instance"],
-                service_account_json=gcs_logging_config["path_service_account"],
-            )
-            bucket_name = gcs_logging_config["bucket_name"]
-
             logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                 "standard_logging_object", None
             )
-
             if logging_payload is None:
                 raise ValueError("standard_logging_object not found in kwargs")

-            _litellm_params = kwargs.get("litellm_params") or {}
-            metadata = _litellm_params.get("metadata") or {}
-
-            # Get the current date
-            current_date = datetime.now().strftime("%Y-%m-%d")
-
-            # Modify the object_name to include the date-based folder
-            object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
-
-            if "gcs_log_id" in metadata:
-                object_name = metadata["gcs_log_id"]
-
-            await self._log_json_data_on_gcs(
-                headers=headers,
-                bucket_name=bucket_name,
-                object_name=object_name,
-                logging_payload=logging_payload,
+            # Add to logging queue - this will be flushed periodically
+            self.log_queue.append(
+                GCSLogQueueItem(
+                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+                )
             )

         except Exception as e:
             verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

+    async def async_send_batch(self):
+        """Process queued logs in batch - sends logs to GCS Bucket"""
+        if not self.log_queue:
+            return
+
+        try:
+            for log_item in self.log_queue:
+                logging_payload = log_item["payload"]
+                kwargs = log_item["kwargs"]
+                response_obj = log_item.get("response_obj", None) or {}
+
+                gcs_logging_config: GCSLoggingConfig = (
+                    await self.get_gcs_logging_config(kwargs)
+                )
+                headers = await self.construct_request_headers(
+                    vertex_instance=gcs_logging_config["vertex_instance"],
+                    service_account_json=gcs_logging_config["path_service_account"],
+                )
+                bucket_name = gcs_logging_config["bucket_name"]
+                object_name = self._get_object_name(
+                    kwargs, logging_payload, response_obj
+                )
+                await self._log_json_data_on_gcs(
+                    headers=headers,
+                    bucket_name=bucket_name,
+                    object_name=object_name,
+                    logging_payload=logging_payload,
+                )
+
+            # Clear the queue after processing
+            self.log_queue.clear()
+
+        except Exception as e:
+            verbose_logger.exception(f"GCS Bucket batch logging error: {str(e)}")
+
+    def _get_object_name(
+        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
+    ) -> str:
+        """
+        Get the object name to use for the current payload
+        """
+        current_date = datetime.now().strftime("%Y-%m-%d")
+        if logging_payload.get("error_str", None) is not None:
+            object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
+        else:
+            object_name = f"{current_date}/{response_obj.get('id', '')}"
+
+        # used for testing
+        _litellm_params = kwargs.get("litellm_params", None) or {}
+        _metadata = _litellm_params.get("metadata", None) or {}
+        if "gcs_log_id" in _metadata:
+            object_name = _metadata["gcs_log_id"]
+
+        return object_name
+
     def _handle_folders_in_bucket_name(
         self,
         bucket_name: str,
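Each queued `GCSLogQueueItem` carries the payload, kwargs, and response object, so the flush can still derive the same object names the old per-request path produced. A condensed restatement of the `_get_object_name` rule above, as a standalone sketch (the example values are made up):

```python
import uuid
from datetime import datetime


def object_name_for(payload: dict, kwargs: dict, response_obj: dict) -> str:
    """Condensed restatement of _get_object_name (illustrative only)."""
    metadata = (kwargs.get("litellm_params") or {}).get("metadata") or {}
    if "gcs_log_id" in metadata:
        return metadata["gcs_log_id"]  # explicit override, used by tests
    current_date = datetime.now().strftime("%Y-%m-%d")
    if payload.get("error_str") is not None:
        return f"{current_date}/failure-{uuid.uuid4().hex}"  # failure logs
    return f"{current_date}/{response_obj.get('id', '')}"  # success: <date>/<response id>


# e.g. object_name_for({"error_str": None}, {}, {"id": "chatcmpl-abc123"})
# -> "2024-11-20/chatcmpl-abc123" (the date depends on when it runs)
```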
diff --git a/litellm/integrations/gcs_bucket/gcs_bucket_base.py b/litellm/integrations/gcs_bucket/gcs_bucket_base.py
index 56df3aa80..9615b9b21 100644
--- a/litellm/integrations/gcs_bucket/gcs_bucket_base.py
+++ b/litellm/integrations/gcs_bucket/gcs_bucket_base.py
@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field

 import litellm
 from litellm._logging import verbose_logger
-from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
 from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
     httpxSpecialProvider,
@@ -21,8 +21,8 @@ else:
     VertexBase = Any


-class GCSBucketBase(CustomLogger):
-    def __init__(self, bucket_name: Optional[str] = None) -> None:
+class GCSBucketBase(CustomBatchLogger):
+    def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
         self.async_httpx_client = get_async_httpx_client(
             llm_provider=httpxSpecialProvider.LoggingCallback
         )
@@ -30,6 +30,7 @@ class GCSBucketBase(CustomLogger):
         _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
         self.path_service_account_json: Optional[str] = _path_service_account
         self.BUCKET_NAME: Optional[str] = _bucket_name
+        super().__init__(**kwargs)

     async def construct_request_headers(
         self,
diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py
index 01baa232f..2c240a17f 100644
--- a/litellm/proxy/management_endpoints/key_management_endpoints.py
+++ b/litellm/proxy/management_endpoints/key_management_endpoints.py
@@ -1599,7 +1599,9 @@ async def test_key_logging(
                 details=f"Logging test failed: {str(e)}",
             )

-    await asyncio.sleep(1)  # wait for callbacks to run
+    await asyncio.sleep(
+        2
+    )  # wait for callbacks to run, callbacks use batching so wait for the flush event

     # Check if any logger exceptions were triggered
     log_contents = log_capture_string.getvalue()
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 694c1613d..b4a18baa4 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -7,10 +7,4 @@ model_list:


 litellm_settings:
-  callbacks: ["prometheus"]
-  service_callback: ["prometheus_system"]
-
-
-general_settings:
-  allow_requests_on_db_unavailable: true
-
+  callbacks: ["gcs_bucket"]
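Why `GCSBucketBase` gains `**kwargs` and a `super().__init__(**kwargs)` call: it now sits between `GCSBucketLogger` and `CustomBatchLogger`, so the batch settings have to travel up the inheritance chain. A stripped-down sketch of that pass-through with stand-in class names (the real constructors do more work than this):

```python
class BatchLoggerSketch:
    def __init__(self, batch_size=None, flush_interval=None, **kwargs):
        # Defaults mirror the GCS constants introduced above.
        self.batch_size = batch_size or 2048
        self.flush_interval = flush_interval or 20


class BucketBaseSketch(BatchLoggerSketch):
    def __init__(self, bucket_name=None, **kwargs):
        self.BUCKET_NAME = bucket_name
        super().__init__(**kwargs)  # forward batch settings to the batch logger


class BucketLoggerSketch(BucketBaseSketch):
    def __init__(self):
        super().__init__(bucket_name="my-bucket", batch_size=100, flush_interval=5)


logger = BucketLoggerSketch()
assert (logger.BUCKET_NAME, logger.batch_size, logger.flush_interval) == ("my-bucket", 100, 5)
```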
diff --git a/litellm/types/integrations/gcs_bucket.py b/litellm/types/integrations/gcs_bucket.py
new file mode 100644
index 000000000..18636ae1f
--- /dev/null
+++ b/litellm/types/integrations/gcs_bucket.py
@@ -0,0 +1,28 @@
+from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
+
+from litellm.types.utils import StandardLoggingPayload
+
+if TYPE_CHECKING:
+    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
+else:
+    VertexBase = Any
+
+
+class GCSLoggingConfig(TypedDict):
+    """
+    Internal LiteLLM Config for GCS Bucket logging
+    """
+
+    bucket_name: str
+    vertex_instance: VertexBase
+    path_service_account: Optional[str]
+
+
+class GCSLogQueueItem(TypedDict):
+    """
+    Internal Type, used for queueing logs to be sent to GCS Bucket
+    """
+
+    payload: StandardLoggingPayload
+    kwargs: Dict[str, Any]
+    response_obj: Optional[Any]
diff --git a/tests/local_testing/test_gcs_bucket.py b/tests/local_testing/test_gcs_bucket.py
index a01e839fa..4d431b662 100644
--- a/tests/local_testing/test_gcs_bucket.py
+++ b/tests/local_testing/test_gcs_bucket.py
@@ -28,6 +28,7 @@ verbose_logger.setLevel(logging.DEBUG)
 def load_vertex_ai_credentials():
     # Define the path to the vertex_key.json file
     print("loading vertex ai credentials")
+    os.environ["GCS_FLUSH_INTERVAL"] = "1"
     filepath = os.path.dirname(os.path.abspath(__file__))
     vertex_key_path = filepath + "/adroit-crow-413218-bc47f303efc9.json"
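Because logs are now buffered, anything that asserts on bucket contents has to wait out at least one flush cycle, which is why the test setup above pins `GCS_FLUSH_INTERVAL` to 1 second. A hedged sketch of that pattern (the completion call is mocked, the bucket assertion is omitted, and the premium-user check in `GCSBucketLogger.__init__` still applies, so the real tests patch it):

```python
import asyncio
import os

import pytest

import litellm
from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger


@pytest.mark.asyncio
async def test_gcs_flush_interval_sketch():
    os.environ["GCS_FLUSH_INTERVAL"] = "1"  # flush every second instead of every 20s

    gcs_logger = GCSBucketLogger()  # assumes GCS_BUCKET_NAME / credentials are configured
    litellm.callbacks = [gcs_logger]

    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response="hello",  # no real provider call needed to exercise logging
    )

    # Wait past the flush interval so async_send_batch() has uploaded the queued log,
    # then fetch "<date>/{response.id}" from the bucket and assert on it (omitted here).
    await asyncio.sleep(3)
```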