(Feat) 273% improvement GCS Bucket Logger - use Batched Logging (#6679)

* use CustomBatchLogger for GCS

* add GCS bucket logging type

* use batch logging for GCS bucket

* add gcs_bucket

* allow setting flush_interval on CustomBatchLogger

* set GCS_FLUSH_INTERVAL to 1s

* fix test_key_logging

* fix test_key_logging

* add docs on new env vars
Ishaan Jaff 2024-11-10 22:05:34 -08:00 committed by GitHub
parent 70aa85af1f
commit eb92ed4156
9 changed files with 128 additions and 72 deletions
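
The change in a nutshell: instead of uploading one object to GCS per request, the logger now queues StandardLoggingPayload entries and flushes them on a size/time trigger. Below is a minimal, self-contained sketch of that batching pattern; the class and method names are illustrative only, not the commit's code:

import asyncio
import time
from typing import Any, List


class BatchingLoggerSketch:
    """Queue log items; flush when the batch is full or the interval elapses."""

    def __init__(self, batch_size: int = 2048, flush_interval: int = 20) -> None:
        self.batch_size = batch_size          # flush after this many queued logs
        self.flush_interval = flush_interval  # ...or after this many seconds
        self.log_queue: List[Any] = []
        self.last_flush_time = time.time()

    async def log_event(self, item: Any) -> None:
        self.log_queue.append(item)
        if len(self.log_queue) >= self.batch_size:
            await self.flush()

    async def periodic_flush(self) -> None:
        # run as a background task, e.g. asyncio.create_task(logger.periodic_flush())
        while True:
            await asyncio.sleep(self.flush_interval)
            if self.log_queue:
                await self.flush()

    async def flush(self) -> None:
        batch, self.log_queue = self.log_queue, []
        self.last_flush_time = time.time()
        # one upload per flush (instead of one per request) is where the
        # claimed throughput improvement comes from
        print(f"uploading {len(batch)} queued log items in one batch")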

View file

@@ -1006,6 +1006,7 @@ jobs:
 -e AWS_REGION_NAME=$AWS_REGION_NAME \
 -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
 -e COHERE_API_KEY=$COHERE_API_KEY \
+-e GCS_FLUSH_INTERVAL="1" \
 --name my-app \
 -v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
 -v $(pwd)/litellm/proxy/example_config_yaml/custom_guardrail.py:/app/custom_guardrail.py \

View file

@@ -934,6 +934,8 @@ router_settings:
 | EMAIL_SUPPORT_CONTACT | Support contact email address
 | GCS_BUCKET_NAME | Name of the Google Cloud Storage bucket
 | GCS_PATH_SERVICE_ACCOUNT | Path to the Google Cloud service account JSON file
+| GCS_FLUSH_INTERVAL | Flush interval for GCS logging (in seconds). Specify how often queued logs are sent to GCS.
+| GCS_BATCH_SIZE | Batch size for GCS logging. Specify how many logs to queue before flushing to GCS. If `GCS_BATCH_SIZE` is set to 10, logs are flushed every 10 logs.
 | GENERIC_AUTHORIZATION_ENDPOINT | Authorization endpoint for generic OAuth providers
 | GENERIC_CLIENT_ID | Client ID for generic OAuth providers
 | GENERIC_CLIENT_SECRET | Client secret for generic OAuth providers
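
For example, the two new variables can be set before starting LiteLLM. A short sketch (the values below are arbitrary; the read-at-init behavior is shown in the GCSBucketLogger diff further down):

import os

# send queued GCS logs every 10 seconds
os.environ["GCS_FLUSH_INTERVAL"] = "10"
# target batch size for each flush (the default added in this commit is 2048)
os.environ["GCS_BATCH_SIZE"] = "100"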

View file

@@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
         self,
         flush_lock: Optional[asyncio.Lock] = None,
         batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
+        flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
         **kwargs,
     ) -> None:
         """
@@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger):
             flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
         """
         self.log_queue: List = []
-        self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS  # 10 seconds
+        self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
         self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
         self.last_flush_time = time.time()
         self.flush_lock = flush_lock
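
With the constructor change above, a subclass can now pass its own interval through super().__init__. A sketch under assumptions: the subclass name and sink are hypothetical, and whether async_send_batch is required by the base class is not shown in this diff; only the constructor signature comes from the hunk above.

import asyncio

from litellm.integrations.custom_batch_logger import CustomBatchLogger


class MyBatchedLogger(CustomBatchLogger):
    """Hypothetical subclass wiring a custom flush interval through super().__init__."""

    def __init__(self) -> None:
        super().__init__(
            flush_lock=asyncio.Lock(),
            batch_size=100,     # flush once 100 logs are queued
            flush_interval=5,   # or every 5 seconds via the periodic flush task
        )

    async def async_send_batch(self) -> None:
        # hypothetical sink; a real logger would ship self.log_queue somewhere
        print(f"sending {len(self.log_queue)} queued logs")
        self.log_queue.clear()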

View file

@@ -1,3 +1,4 @@
+import asyncio
 import json
 import os
 import uuid
@@ -10,10 +11,12 @@ from pydantic import BaseModel, Field
 import litellm
 from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload
+from litellm.types.integrations.gcs_bucket import *
 from litellm.types.utils import (
     StandardCallbackDynamicParams,
     StandardLoggingMetadata,
@@ -27,12 +30,8 @@ else:
 IAM_AUTH_KEY = "IAM_AUTH"

+GCS_DEFAULT_BATCH_SIZE = 2048
+GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
-class GCSLoggingConfig(TypedDict):
-    bucket_name: str
-    vertex_instance: VertexBase
-    path_service_account: Optional[str]

 class GCSBucketLogger(GCSBucketBase):
@@ -41,6 +40,21 @@ class GCSBucketLogger(GCSBucketBase):
         super().__init__(bucket_name=bucket_name)
         self.vertex_instances: Dict[str, VertexBase] = {}
+
+        # Init Batch logging settings
+        self.log_queue: List[GCSLogQueueItem] = []
+        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
+        self.flush_interval = int(
+            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
+        )
+        asyncio.create_task(self.periodic_flush())
+        self.flush_lock = asyncio.Lock()
+        super().__init__(
+            flush_lock=self.flush_lock,
+            batch_size=self.batch_size,
+            flush_interval=self.flush_interval,
+        )
+
         if premium_user is not True:
             raise ValueError(
                 f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
@@ -60,44 +74,23 @@ class GCSBucketLogger(GCSBucketBase):
                 kwargs,
                 response_obj,
             )
-            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
-                kwargs
-            )
-            headers = await self.construct_request_headers(
-                vertex_instance=gcs_logging_config["vertex_instance"],
-                service_account_json=gcs_logging_config["path_service_account"],
-            )
-            bucket_name = gcs_logging_config["bucket_name"]
             logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                 "standard_logging_object", None
             )
             if logging_payload is None:
                 raise ValueError("standard_logging_object not found in kwargs")
-            # Get the current date
-            current_date = datetime.now().strftime("%Y-%m-%d")
-
-            # Modify the object_name to include the date-based folder
-            object_name = f"{current_date}/{response_obj['id']}"
-            await self._log_json_data_on_gcs(
-                headers=headers,
-                bucket_name=bucket_name,
-                object_name=object_name,
-                logging_payload=logging_payload,
-            )
+            # Add to logging queue - this will be flushed periodically
+            self.log_queue.append(
+                GCSLogQueueItem(
+                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+                )
+            )
         except Exception as e:
             verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
-        from litellm.proxy.proxy_server import premium_user
-
-        if premium_user is not True:
-            raise ValueError(
-                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
-            )
         try:
             verbose_logger.debug(
                 "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
@@ -105,44 +98,77 @@ class GCSBucketLogger(GCSBucketBase):
                 response_obj,
             )
-            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
-                kwargs
-            )
-            headers = await self.construct_request_headers(
-                vertex_instance=gcs_logging_config["vertex_instance"],
-                service_account_json=gcs_logging_config["path_service_account"],
-            )
-            bucket_name = gcs_logging_config["bucket_name"]
             logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                 "standard_logging_object", None
             )
             if logging_payload is None:
                 raise ValueError("standard_logging_object not found in kwargs")
-            _litellm_params = kwargs.get("litellm_params") or {}
-            metadata = _litellm_params.get("metadata") or {}
-
-            # Get the current date
-            current_date = datetime.now().strftime("%Y-%m-%d")
-
-            # Modify the object_name to include the date-based folder
-            object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
-
-            if "gcs_log_id" in metadata:
-                object_name = metadata["gcs_log_id"]
-
-            await self._log_json_data_on_gcs(
-                headers=headers,
-                bucket_name=bucket_name,
-                object_name=object_name,
-                logging_payload=logging_payload,
-            )
+            # Add to logging queue - this will be flushed periodically
+            self.log_queue.append(
+                GCSLogQueueItem(
+                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+                )
+            )
         except Exception as e:
             verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

+    async def async_send_batch(self):
+        """Process queued logs in batch - sends logs to GCS Bucket"""
+        if not self.log_queue:
+            return
+        try:
+            for log_item in self.log_queue:
+                logging_payload = log_item["payload"]
+                kwargs = log_item["kwargs"]
+                response_obj = log_item.get("response_obj", None) or {}
+
+                gcs_logging_config: GCSLoggingConfig = (
+                    await self.get_gcs_logging_config(kwargs)
+                )
+                headers = await self.construct_request_headers(
+                    vertex_instance=gcs_logging_config["vertex_instance"],
+                    service_account_json=gcs_logging_config["path_service_account"],
+                )
+                bucket_name = gcs_logging_config["bucket_name"]
+                object_name = self._get_object_name(
+                    kwargs, logging_payload, response_obj
+                )
+                await self._log_json_data_on_gcs(
+                    headers=headers,
+                    bucket_name=bucket_name,
+                    object_name=object_name,
+                    logging_payload=logging_payload,
+                )
+            # Clear the queue after processing
+            self.log_queue.clear()
+        except Exception as e:
+            verbose_logger.exception(f"GCS Bucket batch logging error: {str(e)}")
+
+    def _get_object_name(
+        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
+    ) -> str:
+        """
+        Get the object name to use for the current payload
+        """
+        current_date = datetime.now().strftime("%Y-%m-%d")
+        if logging_payload.get("error_str", None) is not None:
+            object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
+        else:
+            object_name = f"{current_date}/{response_obj.get('id', '')}"
+
+        # used for testing
+        _litellm_params = kwargs.get("litellm_params", None) or {}
+        _metadata = _litellm_params.get("metadata", None) or {}
+        if "gcs_log_id" in _metadata:
+            object_name = _metadata["gcs_log_id"]
+
+        return object_name
+
     def _handle_folders_in_bucket_name(
         self,
         bucket_name: str,
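
End to end, enabling the batched logger looks roughly like the sketch below. Treat it as an assumption-laden example: the "gcs_bucket" callback name and env vars come from this commit, the bucket name, key path, and model are placeholders, an OpenAI key is assumed for the completion call, and the constructor above gates the logger behind a premium/enterprise check.

import os
import litellm

os.environ["GCS_BUCKET_NAME"] = "my-litellm-logs"                          # placeholder
os.environ["GCS_PATH_SERVICE_ACCOUNT"] = "/path/to/service_account.json"   # placeholder
os.environ["GCS_FLUSH_INTERVAL"] = "20"   # seconds between flushes
os.environ["GCS_BATCH_SIZE"] = "2048"     # target logs per flush

litellm.callbacks = ["gcs_bucket"]

response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hi"}],
)
# the log entry is only queued here; it is written to GCS by async_send_batch
# on the next periodic flush rather than immediately after the request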

View file

@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
 import litellm
 from litellm._logging import verbose_logger
-from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
 from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
     httpxSpecialProvider,
@@ -21,8 +21,8 @@ else:
     VertexBase = Any

-class GCSBucketBase(CustomLogger):
-    def __init__(self, bucket_name: Optional[str] = None) -> None:
+class GCSBucketBase(CustomBatchLogger):
+    def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
         self.async_httpx_client = get_async_httpx_client(
             llm_provider=httpxSpecialProvider.LoggingCallback
         )
@@ -30,6 +30,7 @@ class GCSBucketBase(CustomLogger):
         _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
         self.path_service_account_json: Optional[str] = _path_service_account
         self.BUCKET_NAME: Optional[str] = _bucket_name
+        super().__init__(**kwargs)

     async def construct_request_headers(
         self,

View file

@@ -1599,7 +1599,9 @@ async def test_key_logging(
             details=f"Logging test failed: {str(e)}",
         )

-    await asyncio.sleep(1)  # wait for callbacks to run
+    await asyncio.sleep(
+        2
+    )  # wait for callbacks to run, callbacks use batching so wait for the flush event

     # Check if any logger exceptions were triggered
     log_contents = log_capture_string.getvalue()

View file

@@ -7,10 +7,4 @@ model_list:

 litellm_settings:
-  callbacks: ["prometheus"]
-  service_callback: ["prometheus_system"]
-
-general_settings:
-  allow_requests_on_db_unavailable: true
+  callbacks: ["gcs_bucket"]

View file

@@ -0,0 +1,28 @@
+from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
+
+from litellm.types.utils import StandardLoggingPayload
+
+if TYPE_CHECKING:
+    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
+else:
+    VertexBase = Any
+
+
+class GCSLoggingConfig(TypedDict):
+    """
+    Internal LiteLLM Config for GCS Bucket logging
+    """
+
+    bucket_name: str
+    vertex_instance: VertexBase
+    path_service_account: Optional[str]
+
+
+class GCSLogQueueItem(TypedDict):
+    """
+    Internal Type, used for queueing logs to be sent to GCS Bucket
+    """
+
+    payload: StandardLoggingPayload
+    kwargs: Dict[str, Any]
+    response_obj: Optional[Any]
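
Since these are TypedDicts, queue items are plain dicts at runtime. A small sketch of building a GCSLogQueueItem follows; the payload fields are stand-ins, not a full StandardLoggingPayload, and the queue variable is local to the example.

from typing import List

# import path matches the new types file added in this commit
from litellm.types.integrations.gcs_bucket import GCSLogQueueItem

log_queue: List[GCSLogQueueItem] = []

item = GCSLogQueueItem(
    payload={"id": "chatcmpl-123", "error_str": None},  # stand-in payload fields
    kwargs={},          # raw callback kwargs would go here
    response_obj=None,  # may be None, e.g. for failure events
)
log_queue.append(item)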

View file

@@ -28,6 +28,7 @@ verbose_logger.setLevel(logging.DEBUG)
 def load_vertex_ai_credentials():
     # Define the path to the vertex_key.json file
     print("loading vertex ai credentials")
+    os.environ["GCS_FLUSH_INTERVAL"] = "1"
     filepath = os.path.dirname(os.path.abspath(__file__))
     vertex_key_path = filepath + "/adroit-crow-413218-bc47f303efc9.json"