forked from phoenix/litellm-mirror
(Feat) 273% improvement GCS Bucket Logger - use Batched Logging (#6679)
* use CustomBatchLogger for GCS
* add GCS bucket logging type
* use batch logging for GCS bucket
* add gcs_bucket
* allow setting flush_interval on CustomBatchLogger
* set GCS_FLUSH_INTERVAL to 1s
* fix test_key_logging
* fix test_key_logging
* add docs on new env vars
parent 70aa85af1f
commit eb92ed4156
9 changed files with 128 additions and 72 deletions
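For orientation, the sketch below shows how the batched GCS logger is switched on and tuned. The callback name and the four env vars come from this diff; the SDK-style wiring (litellm.callbacks = [...], litellm.completion) is an assumption based on LiteLLM's usual callback pattern, not something this commit adds.

import os

import litellm

os.environ["GCS_BUCKET_NAME"] = "my-gcs-logs-bucket"  # assumed bucket name
os.environ["GCS_PATH_SERVICE_ACCOUNT"] = "/path/to/service_account.json"  # assumed path
os.environ["GCS_FLUSH_INTERVAL"] = "20"  # send queued logs every 20 seconds...
os.environ["GCS_BATCH_SIZE"] = "2048"    # ...or as soon as 2048 logs are queued

# assumed SDK-style wiring; the proxy sets callbacks in config.yaml instead.
# note: the diff guards GCS bucket logging behind a premium_user check.
litellm.callbacks = ["gcs_bucket"]

response = litellm.completion(
    model="gpt-4o-mini",  # assumed model name
    messages=[{"role": "user", "content": "hello"}],
)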
@@ -1006,6 +1006,7 @@ jobs:
          -e AWS_REGION_NAME=$AWS_REGION_NAME \
          -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
          -e COHERE_API_KEY=$COHERE_API_KEY \
+         -e GCS_FLUSH_INTERVAL="1" \
          --name my-app \
          -v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
          -v $(pwd)/litellm/proxy/example_config_yaml/custom_guardrail.py:/app/custom_guardrail.py \

@@ -934,6 +934,8 @@ router_settings:
 | EMAIL_SUPPORT_CONTACT | Support contact email address
 | GCS_BUCKET_NAME | Name of the Google Cloud Storage bucket
 | GCS_PATH_SERVICE_ACCOUNT | Path to the Google Cloud service account JSON file
+| GCS_FLUSH_INTERVAL | Flush interval for GCS logging (in seconds). Specify how often you want a log to be sent to GCS.
+| GCS_BATCH_SIZE | Batch size for GCS logging. Specify after how many logs you want to flush to GCS. If `BATCH_SIZE` is set to 10, logs are flushed every 10 logs.
 | GENERIC_AUTHORIZATION_ENDPOINT | Authorization endpoint for generic OAuth providers
 | GENERIC_CLIENT_ID | Client ID for generic OAuth providers
 | GENERIC_CLIENT_SECRET | Client secret for generic OAuth providers

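A toy illustration of the flush rule the GCS_FLUSH_INTERVAL / GCS_BATCH_SIZE rows added above describe: queued logs are sent once either the batch size is reached or the flush interval has elapsed. This is a simplified stand-in, not the actual CustomBatchLogger code.

import time

GCS_BATCH_SIZE = 10       # flush after 10 queued logs...
GCS_FLUSH_INTERVAL = 20   # ...or after 20 seconds, whichever comes first

log_queue: list = []
last_flush_time = time.time()


def should_flush() -> bool:
    # size-based OR time-based trigger, as described in the table above
    return (
        len(log_queue) >= GCS_BATCH_SIZE
        or (time.time() - last_flush_time) >= GCS_FLUSH_INTERVAL
    )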
@@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
         self,
         flush_lock: Optional[asyncio.Lock] = None,
         batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
+        flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
         **kwargs,
     ) -> None:
         """
@@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger):
             flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
         """
         self.log_queue: List = []
-        self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS  # 10 seconds
+        self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
         self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
         self.last_flush_time = time.time()
         self.flush_lock = flush_lock

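With the new flush_interval parameter, a batching callback can choose its own flush cadence. Below is a minimal sketch of a subclass wiring it up, assuming a running asyncio event loop; the MyBatchedLogger class and its print-based "upload" are hypothetical, while CustomBatchLogger, periodic_flush, async_send_batch, and log_queue come from this diff.

import asyncio

from litellm.integrations.custom_batch_logger import CustomBatchLogger


class MyBatchedLogger(CustomBatchLogger):  # hypothetical subclass
    def __init__(self):
        self.flush_lock = asyncio.Lock()
        super().__init__(flush_lock=self.flush_lock, batch_size=100, flush_interval=5)
        # periodic_flush() wakes up every flush_interval seconds and calls async_send_batch();
        # creating the task requires a running event loop, as in GCSBucketLogger.__init__
        asyncio.create_task(self.periodic_flush())

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # enqueue only; the periodic flush task sends the batch later
        self.log_queue.append({"kwargs": kwargs, "response": response_obj})

    async def async_send_batch(self):
        if not self.log_queue:
            return
        print(f"flushing {len(self.log_queue)} queued logs")  # stand-in for a real upload
        self.log_queue.clear()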
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import os
 import uuid
@@ -10,10 +11,12 @@ from pydantic import BaseModel, Field

 import litellm
 from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload
+from litellm.types.integrations.gcs_bucket import *
 from litellm.types.utils import (
     StandardCallbackDynamicParams,
     StandardLoggingMetadata,
@@ -27,12 +30,8 @@ else:


 IAM_AUTH_KEY = "IAM_AUTH"
-
-
-class GCSLoggingConfig(TypedDict):
-    bucket_name: str
-    vertex_instance: VertexBase
-    path_service_account: Optional[str]
+GCS_DEFAULT_BATCH_SIZE = 2048
+GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20


 class GCSBucketLogger(GCSBucketBase):
@@ -41,6 +40,21 @@ class GCSBucketLogger(GCSBucketBase):
         super().__init__(bucket_name=bucket_name)
         self.vertex_instances: Dict[str, VertexBase] = {}
+
+        # Init Batch logging settings
+        self.log_queue: List[GCSLogQueueItem] = []
+        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
+        self.flush_interval = int(
+            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
+        )
+        asyncio.create_task(self.periodic_flush())
+        self.flush_lock = asyncio.Lock()
+        super().__init__(
+            flush_lock=self.flush_lock,
+            batch_size=self.batch_size,
+            flush_interval=self.flush_interval,
+        )

         if premium_user is not True:
             raise ValueError(
                 f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
@@ -60,44 +74,23 @@ class GCSBucketLogger(GCSBucketBase):
                 kwargs,
                 response_obj,
             )
-            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
-                kwargs
-            )
-            headers = await self.construct_request_headers(
-                vertex_instance=gcs_logging_config["vertex_instance"],
-                service_account_json=gcs_logging_config["path_service_account"],
-            )
-            bucket_name = gcs_logging_config["bucket_name"]
-
             logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                 "standard_logging_object", None
             )

             if logging_payload is None:
                 raise ValueError("standard_logging_object not found in kwargs")

-            # Get the current date
-            current_date = datetime.now().strftime("%Y-%m-%d")
-
-            # Modify the object_name to include the date-based folder
-            object_name = f"{current_date}/{response_obj['id']}"
-
-            await self._log_json_data_on_gcs(
-                headers=headers,
-                bucket_name=bucket_name,
-                object_name=object_name,
-                logging_payload=logging_payload,
-            )
+            # Add to logging queue - this will be flushed periodically
+            self.log_queue.append(
+                GCSLogQueueItem(
+                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+                )
+            )

         except Exception as e:
             verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
-        from litellm.proxy.proxy_server import premium_user
-
-        if premium_user is not True:
-            raise ValueError(
-                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
-            )
         try:
             verbose_logger.debug(
                 "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
@@ -105,44 +98,77 @@ class GCSBucketLogger(GCSBucketBase):
                 response_obj,
             )

-            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
-                kwargs
-            )
-            headers = await self.construct_request_headers(
-                vertex_instance=gcs_logging_config["vertex_instance"],
-                service_account_json=gcs_logging_config["path_service_account"],
-            )
-            bucket_name = gcs_logging_config["bucket_name"]
-
             logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                 "standard_logging_object", None
            )

             if logging_payload is None:
                 raise ValueError("standard_logging_object not found in kwargs")

-            _litellm_params = kwargs.get("litellm_params") or {}
-            metadata = _litellm_params.get("metadata") or {}
-
-            # Get the current date
-            current_date = datetime.now().strftime("%Y-%m-%d")
-
-            # Modify the object_name to include the date-based folder
-            object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
-
-            if "gcs_log_id" in metadata:
-                object_name = metadata["gcs_log_id"]
-
-            await self._log_json_data_on_gcs(
-                headers=headers,
-                bucket_name=bucket_name,
-                object_name=object_name,
-                logging_payload=logging_payload,
-            )
+            # Add to logging queue - this will be flushed periodically
+            self.log_queue.append(
+                GCSLogQueueItem(
+                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+                )
+            )

         except Exception as e:
             verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

+    async def async_send_batch(self):
+        """Process queued logs in batch - sends logs to GCS Bucket"""
+        if not self.log_queue:
+            return
+
+        try:
+            for log_item in self.log_queue:
+                logging_payload = log_item["payload"]
+                kwargs = log_item["kwargs"]
+                response_obj = log_item.get("response_obj", None) or {}
+
+                gcs_logging_config: GCSLoggingConfig = (
+                    await self.get_gcs_logging_config(kwargs)
+                )
+                headers = await self.construct_request_headers(
+                    vertex_instance=gcs_logging_config["vertex_instance"],
+                    service_account_json=gcs_logging_config["path_service_account"],
+                )
+                bucket_name = gcs_logging_config["bucket_name"]
+                object_name = self._get_object_name(
+                    kwargs, logging_payload, response_obj
+                )
+                await self._log_json_data_on_gcs(
+                    headers=headers,
+                    bucket_name=bucket_name,
+                    object_name=object_name,
+                    logging_payload=logging_payload,
+                )
+
+            # Clear the queue after processing
+            self.log_queue.clear()
+
+        except Exception as e:
+            verbose_logger.exception(f"GCS Bucket batch logging error: {str(e)}")
+
+    def _get_object_name(
+        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
+    ) -> str:
+        """
+        Get the object name to use for the current payload
+        """
+        current_date = datetime.now().strftime("%Y-%m-%d")
+        if logging_payload.get("error_str", None) is not None:
+            object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
+        else:
+            object_name = f"{current_date}/{response_obj.get('id', '')}"
+
+        # used for testing
+        _litellm_params = kwargs.get("litellm_params", None) or {}
+        _metadata = _litellm_params.get("metadata", None) or {}
+        if "gcs_log_id" in _metadata:
+            object_name = _metadata["gcs_log_id"]
+
+        return object_name
+
     def _handle_folders_in_bucket_name(
         self,
         bucket_name: str,

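The new _get_object_name helper centralizes the GCS object naming that the success and failure handlers used to compute inline. A standalone illustration of the resulting layout; the sample response id and gcs_log_id value are made up.

import uuid
from datetime import datetime

current_date = datetime.now().strftime("%Y-%m-%d")

success_object = f"{current_date}/chatcmpl-abc123"  # hypothetical response id
failure_object = f"{current_date}/failure-{uuid.uuid4().hex}"
pinned_object = "my-custom-log-id"  # what metadata["gcs_log_id"] would force instead

print(success_object)
print(failure_object)
print(pinned_object)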
@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field

 import litellm
 from litellm._logging import verbose_logger
-from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
 from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
     httpxSpecialProvider,
@@ -21,8 +21,8 @@ else:
     VertexBase = Any


-class GCSBucketBase(CustomLogger):
-    def __init__(self, bucket_name: Optional[str] = None) -> None:
+class GCSBucketBase(CustomBatchLogger):
+    def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
         self.async_httpx_client = get_async_httpx_client(
             llm_provider=httpxSpecialProvider.LoggingCallback
         )
@@ -30,6 +30,7 @@ class GCSBucketBase(CustomLogger):
         _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
         self.path_service_account_json: Optional[str] = _path_service_account
         self.BUCKET_NAME: Optional[str] = _bucket_name
+        super().__init__(**kwargs)

     async def construct_request_headers(
         self,

@@ -1599,7 +1599,9 @@ async def test_key_logging(
             details=f"Logging test failed: {str(e)}",
         )

-    await asyncio.sleep(1)  # wait for callbacks to run
+    await asyncio.sleep(
+        2
+    )  # wait for callbacks to run, callbacks use batching so wait for the flush event

     # Check if any logger exceptions were triggered
     log_contents = log_capture_string.getvalue()

@@ -7,10 +7,4 @@ model_list:


 litellm_settings:
-  callbacks: ["prometheus"]
-  service_callback: ["prometheus_system"]
-
-general_settings:
-  allow_requests_on_db_unavailable: true
+  callbacks: ["gcs_bucket"]

litellm/types/integrations/gcs_bucket.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
+
+from litellm.types.utils import StandardLoggingPayload
+
+if TYPE_CHECKING:
+    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
+else:
+    VertexBase = Any
+
+
+class GCSLoggingConfig(TypedDict):
+    """
+    Internal LiteLLM Config for GCS Bucket logging
+    """
+
+    bucket_name: str
+    vertex_instance: VertexBase
+    path_service_account: Optional[str]
+
+
+class GCSLogQueueItem(TypedDict):
+    """
+    Internal Type, used for queueing logs to be sent to GCS Bucket
+    """
+
+    payload: StandardLoggingPayload
+    kwargs: Dict[str, Any]
+    response_obj: Optional[Any]

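A short sketch of how a GCSLogQueueItem is built and later read back, mirroring the append/flush code above; the trimmed payload dict is a stand-in for a full StandardLoggingPayload, for illustration only.

from typing import Any, Dict, cast

from litellm.types.integrations.gcs_bucket import GCSLogQueueItem
from litellm.types.utils import StandardLoggingPayload

# trimmed stand-in for a full StandardLoggingPayload
payload = cast(StandardLoggingPayload, {"id": "chatcmpl-abc123", "error_str": None})

item = GCSLogQueueItem(payload=payload, kwargs={}, response_obj=None)

# the same access pattern async_send_batch uses
logging_payload = item["payload"]
kwargs: Dict[str, Any] = item["kwargs"]
response_obj = item.get("response_obj", None) or {}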
@@ -28,6 +28,7 @@ verbose_logger.setLevel(logging.DEBUG)
 def load_vertex_ai_credentials():
     # Define the path to the vertex_key.json file
     print("loading vertex ai credentials")
+    os.environ["GCS_FLUSH_INTERVAL"] = "1"
     filepath = os.path.dirname(os.path.abspath(__file__))
     vertex_key_path = filepath + "/adroit-crow-413218-bc47f303efc9.json"
