(Feat) 273% improvement GCS Bucket Logger - use Batched Logging (#6679)
* use CustomBatchLogger for GCS
* add GCS bucket logging type
* use batch logging for GCS bucket
* add gcs_bucket
* allow setting flush_interval on CustomBatchLogger
* set GCS_FLUSH_INTERVAL to 1s
* fix test_key_logging
* add docs on new env vars
parent 70aa85af1f
commit eb92ed4156
9 changed files with 128 additions and 72 deletions
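Before the per-file hunks, here is a rough standalone sketch of the batching pattern the commit message describes: each request only appends its log payload to an in-memory queue, and a background task flushes the whole queue to GCS on a fixed interval. The class, the fake upload delay, and the timing values below are illustrative assumptions, not code from this diff.

```python
import asyncio
import time
from typing import Any, Dict, List


class BatchedLoggerSketch:
    """Toy stand-in for the batched GCS logger: queue now, upload later."""

    def __init__(self, flush_interval: float = 1.0, batch_size: int = 2048) -> None:
        self.log_queue: List[Dict[str, Any]] = []
        self.flush_interval = flush_interval
        self.batch_size = batch_size

    def log_event(self, payload: Dict[str, Any]) -> None:
        # Hot path: no network call per request, just an append.
        self.log_queue.append(payload)

    async def periodic_flush(self) -> None:
        while True:
            await asyncio.sleep(self.flush_interval)
            await self.flush()

    async def flush(self) -> None:
        if not self.log_queue:
            return
        batch = self.log_queue[: self.batch_size]
        self.log_queue = self.log_queue[self.batch_size :]
        await asyncio.sleep(0.05)  # pretend this is one upload round-trip for the whole batch
        print(f"{time.strftime('%X')} flushed {len(batch)} events in a single upload")


async def main() -> None:
    logger = BatchedLoggerSketch(flush_interval=1.0)
    flusher = asyncio.create_task(logger.periodic_flush())
    for i in range(100):
        logger.log_event({"id": f"req-{i}"})
    await asyncio.sleep(1.5)  # give the background task time to flush once
    flusher.cancel()


if __name__ == "__main__":
    asyncio.run(main())
```

Amortizing the GCS round-trip over a whole batch, rather than paying it once per request, is where the throughput improvement claimed in the title comes from.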
@@ -1006,6 +1006,7 @@ jobs:
 -e AWS_REGION_NAME=$AWS_REGION_NAME \
 -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
 -e COHERE_API_KEY=$COHERE_API_KEY \
+-e GCS_FLUSH_INTERVAL="1" \
 --name my-app \
 -v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
 -v $(pwd)/litellm/proxy/example_config_yaml/custom_guardrail.py:/app/custom_guardrail.py \
@@ -934,6 +934,8 @@ router_settings:
 | EMAIL_SUPPORT_CONTACT | Support contact email address |
 | GCS_BUCKET_NAME | Name of the Google Cloud Storage bucket |
 | GCS_PATH_SERVICE_ACCOUNT | Path to the Google Cloud service account JSON file |
+| GCS_FLUSH_INTERVAL | Flush interval for GCS logging (in seconds). Controls how often queued logs are sent to GCS. |
+| GCS_BATCH_SIZE | Batch size for GCS logging. Controls after how many logs a flush to GCS is triggered. If `GCS_BATCH_SIZE` is set to 10, logs are flushed every 10 logs. |
 | GENERIC_AUTHORIZATION_ENDPOINT | Authorization endpoint for generic OAuth providers |
 | GENERIC_CLIENT_ID | Client ID for generic OAuth providers |
 | GENERIC_CLIENT_SECRET | Client secret for generic OAuth providers |
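As a quick illustration of how these two variables are consumed (see the `GCSBucketLogger.__init__` hunk further down), here is a hedged sketch of the resolution order; the default constants mirror the ones this commit adds, and everything else is illustrative.

```python
import os

# Defaults introduced in this commit (see the gcs_bucket.py hunks below)
GCS_DEFAULT_BATCH_SIZE = 2048
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20

# Mirrors how GCSBucketLogger.__init__ resolves its batching settings:
# an env var, when set, overrides the default, and both are coerced to int.
batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
flush_interval = int(os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS))

print(f"GCS logging: flush interval {flush_interval}s, batch size {batch_size}")
```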
@@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
         self,
         flush_lock: Optional[asyncio.Lock] = None,
         batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
+        flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
         **kwargs,
     ) -> None:
         """
@@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger):
             flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
         """
         self.log_queue: List = []
-        self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS  # 10 seconds
+        self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
         self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
         self.last_flush_time = time.time()
         self.flush_lock = flush_lock
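To show where the new `flush_interval` argument ends up, here is a hedged sketch of a custom logger passing it through to `CustomBatchLogger`. The `MyBatchLogger` class, its `batch_size` of 100, and the body of `async_send_batch` are made up for illustration, and the sketch assumes `periodic_flush()` (already used by the GCS logger below) drives `async_send_batch` on the configured interval.

```python
import asyncio

from litellm.integrations.custom_batch_logger import CustomBatchLogger


class MyBatchLogger(CustomBatchLogger):  # hypothetical subclass, for illustration only
    def __init__(self, flush_interval: int = 5) -> None:
        self.flush_lock = asyncio.Lock()
        # flush_interval is the constructor argument this hunk introduces; when omitted
        # it falls back to DEFAULT_FLUSH_INTERVAL_SECONDS.
        super().__init__(
            flush_lock=self.flush_lock,
            batch_size=100,
            flush_interval=flush_interval,
        )

    async def async_send_batch(self):
        # Ship the queued items to a backend here, then clear the queue.
        print(f"flushing {len(self.log_queue)} queued log items")
        self.log_queue.clear()


async def main() -> None:
    logger = MyBatchLogger(flush_interval=1)
    flusher = asyncio.create_task(logger.periodic_flush())
    logger.log_queue.append({"event": "test"})
    await asyncio.sleep(2)  # long enough for at least one flush at flush_interval=1
    flusher.cancel()


if __name__ == "__main__":
    asyncio.run(main())
```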
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import os
 import uuid
@@ -10,10 +11,12 @@ from pydantic import BaseModel, Field

 import litellm
 from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload
+from litellm.types.integrations.gcs_bucket import *
 from litellm.types.utils import (
     StandardCallbackDynamicParams,
     StandardLoggingMetadata,
@@ -27,12 +30,8 @@ else:


 IAM_AUTH_KEY = "IAM_AUTH"
-
-
-class GCSLoggingConfig(TypedDict):
-    bucket_name: str
-    vertex_instance: VertexBase
-    path_service_account: Optional[str]
+GCS_DEFAULT_BATCH_SIZE = 2048
+GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20


 class GCSBucketLogger(GCSBucketBase):
@@ -41,6 +40,21 @@ class GCSBucketLogger(GCSBucketBase):

         super().__init__(bucket_name=bucket_name)
         self.vertex_instances: Dict[str, VertexBase] = {}
+
+        # Init Batch logging settings
+        self.log_queue: List[GCSLogQueueItem] = []
+        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
+        self.flush_interval = int(
+            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
+        )
+        asyncio.create_task(self.periodic_flush())
+        self.flush_lock = asyncio.Lock()
+        super().__init__(
+            flush_lock=self.flush_lock,
+            batch_size=self.batch_size,
+            flush_interval=self.flush_interval,
+        )
+
         if premium_user is not True:
             raise ValueError(
                 f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
@@ -60,44 +74,23 @@ class GCSBucketLogger(GCSBucketBase):
                 kwargs,
                 response_obj,
             )
-            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
-                kwargs
-            )
-            headers = await self.construct_request_headers(
-                vertex_instance=gcs_logging_config["vertex_instance"],
-                service_account_json=gcs_logging_config["path_service_account"],
-            )
-            bucket_name = gcs_logging_config["bucket_name"]

             logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                 "standard_logging_object", None
             )

             if logging_payload is None:
                 raise ValueError("standard_logging_object not found in kwargs")

-            # Get the current date
-            current_date = datetime.now().strftime("%Y-%m-%d")
-
-            # Modify the object_name to include the date-based folder
-            object_name = f"{current_date}/{response_obj['id']}"
-
-            await self._log_json_data_on_gcs(
-                headers=headers,
-                bucket_name=bucket_name,
-                object_name=object_name,
-                logging_payload=logging_payload,
+            # Add to logging queue - this will be flushed periodically
+            self.log_queue.append(
+                GCSLogQueueItem(
+                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+                )
             )

         except Exception as e:
             verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         from litellm.proxy.proxy_server import premium_user

         if premium_user is not True:
             raise ValueError(
                 f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
             )
         try:
             verbose_logger.debug(
                 "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
@@ -105,44 +98,77 @@ class GCSBucketLogger(GCSBucketBase):
                 response_obj,
             )

-            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
-                kwargs
-            )
-            headers = await self.construct_request_headers(
-                vertex_instance=gcs_logging_config["vertex_instance"],
-                service_account_json=gcs_logging_config["path_service_account"],
-            )
-            bucket_name = gcs_logging_config["bucket_name"]

             logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                 "standard_logging_object", None
             )

             if logging_payload is None:
                 raise ValueError("standard_logging_object not found in kwargs")

-            _litellm_params = kwargs.get("litellm_params") or {}
-            metadata = _litellm_params.get("metadata") or {}
-
-            # Get the current date
-            current_date = datetime.now().strftime("%Y-%m-%d")
-
-            # Modify the object_name to include the date-based folder
-            object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
-
-            if "gcs_log_id" in metadata:
-                object_name = metadata["gcs_log_id"]
-
-            await self._log_json_data_on_gcs(
-                headers=headers,
-                bucket_name=bucket_name,
-                object_name=object_name,
-                logging_payload=logging_payload,
+            # Add to logging queue - this will be flushed periodically
+            self.log_queue.append(
+                GCSLogQueueItem(
+                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+                )
             )

         except Exception as e:
             verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

+    async def async_send_batch(self):
+        """Process queued logs in batch - sends logs to GCS Bucket"""
+        if not self.log_queue:
+            return
+
+        try:
+            for log_item in self.log_queue:
+                logging_payload = log_item["payload"]
+                kwargs = log_item["kwargs"]
+                response_obj = log_item.get("response_obj", None) or {}
+
+                gcs_logging_config: GCSLoggingConfig = (
+                    await self.get_gcs_logging_config(kwargs)
+                )
+                headers = await self.construct_request_headers(
+                    vertex_instance=gcs_logging_config["vertex_instance"],
+                    service_account_json=gcs_logging_config["path_service_account"],
+                )
+                bucket_name = gcs_logging_config["bucket_name"]
+                object_name = self._get_object_name(
+                    kwargs, logging_payload, response_obj
+                )
+                await self._log_json_data_on_gcs(
+                    headers=headers,
+                    bucket_name=bucket_name,
+                    object_name=object_name,
+                    logging_payload=logging_payload,
+                )
+
+            # Clear the queue after processing
+            self.log_queue.clear()
+
+        except Exception as e:
+            verbose_logger.exception(f"GCS Bucket batch logging error: {str(e)}")
+
+    def _get_object_name(
+        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
+    ) -> str:
+        """
+        Get the object name to use for the current payload
+        """
+        current_date = datetime.now().strftime("%Y-%m-%d")
+        if logging_payload.get("error_str", None) is not None:
+            object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
+        else:
+            object_name = f"{current_date}/{response_obj.get('id', '')}"
+
+        # used for testing
+        _litellm_params = kwargs.get("litellm_params", None) or {}
+        _metadata = _litellm_params.get("metadata", None) or {}
+        if "gcs_log_id" in _metadata:
+            object_name = _metadata["gcs_log_id"]
+
+        return object_name
+
     def _handle_folders_in_bucket_name(
         self,
         bucket_name: str,
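`_get_object_name` centralizes the naming that the success and failure handlers previously computed inline. Below is a small hedged sketch of the resulting object names; the dictionaries are placeholders rather than real `StandardLoggingPayload` objects.

```python
import uuid
from datetime import datetime

current_date = datetime.now().strftime("%Y-%m-%d")

# Success path: "<date>/<response id>"
response_obj = {"id": "chatcmpl-abc123"}           # placeholder response
success_object_name = f"{current_date}/{response_obj.get('id', '')}"

# Failure path (payload carries error_str): "<date>/failure-<random hex>"
failure_object_name = f"{current_date}/failure-{uuid.uuid4().hex}"

# Test override: metadata["gcs_log_id"], when present, wins over both
metadata = {"gcs_log_id": "my-fixed-object-name"}  # placeholder metadata
object_name = metadata.get("gcs_log_id", success_object_name)

print(success_object_name, failure_object_name, object_name)
```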
@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field

 import litellm
 from litellm._logging import verbose_logger
-from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
 from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
     httpxSpecialProvider,
@@ -21,8 +21,8 @@ else:
     VertexBase = Any


-class GCSBucketBase(CustomLogger):
-    def __init__(self, bucket_name: Optional[str] = None) -> None:
+class GCSBucketBase(CustomBatchLogger):
+    def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
         self.async_httpx_client = get_async_httpx_client(
             llm_provider=httpxSpecialProvider.LoggingCallback
         )
@@ -30,6 +30,7 @@ class GCSBucketBase(CustomLogger):
         _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
         self.path_service_account_json: Optional[str] = _path_service_account
         self.BUCKET_NAME: Optional[str] = _bucket_name
+        super().__init__(**kwargs)

     async def construct_request_headers(
         self,
@@ -1599,7 +1599,9 @@ async def test_key_logging(
             details=f"Logging test failed: {str(e)}",
         )

-    await asyncio.sleep(1)  # wait for callbacks to run
+    await asyncio.sleep(
+        2
+    )  # wait for callbacks to run; callbacks use batching, so wait for the flush event

     # Check if any logger exceptions were triggered
     log_contents = log_capture_string.getvalue()
@@ -7,10 +7,4 @@ model_list:


 litellm_settings:
-  callbacks: ["prometheus"]
-  service_callback: ["prometheus_system"]
-
-
-general_settings:
-  allow_requests_on_db_unavailable: true
-
+  callbacks: ["gcs_bucket"]
litellm/types/integrations/gcs_bucket.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict

from litellm.types.utils import StandardLoggingPayload

if TYPE_CHECKING:
    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
else:
    VertexBase = Any


class GCSLoggingConfig(TypedDict):
    """
    Internal LiteLLM Config for GCS Bucket logging
    """

    bucket_name: str
    vertex_instance: VertexBase
    path_service_account: Optional[str]


class GCSLogQueueItem(TypedDict):
    """
    Internal Type, used for queueing logs to be sent to GCS Bucket
    """

    payload: StandardLoggingPayload
    kwargs: Dict[str, Any]
    response_obj: Optional[Any]
@@ -28,6 +28,7 @@ verbose_logger.setLevel(logging.DEBUG)
 def load_vertex_ai_credentials():
     # Define the path to the vertex_key.json file
     print("loading vertex ai credentials")
+    os.environ["GCS_FLUSH_INTERVAL"] = "1"
     filepath = os.path.dirname(os.path.abspath(__file__))
     vertex_key_path = filepath + "/adroit-crow-413218-bc47f303efc9.json"