litellm/litellm/integrations/gcs_bucket.py
Ishaan Jaff 21e05a0f3e
(feat proxy) add key based logging for GCS bucket (#6031)
* init litellm langfuse / gcs credentials in litellm logging obj

* add gcs key based test

* rename vars

* save standard_callback_dynamic_params in model call details

* add working gcs bucket key based logging

* test_basic_gcs_logging_per_request

* linting fix

* add doc on GCS bucket team-based logging
2024-10-03 15:24:31 +05:30

306 lines
12 KiB
Python

import json
import os
import uuid
from datetime import datetime
from re import S
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union
import httpx
from pydantic import BaseModel, Field
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.gcs_bucket_base import GCSBucketBase
from litellm.litellm_core_utils.logging_utils import (
convert_litellm_response_object_to_dict,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload
from litellm.types.utils import (
StandardCallbackDynamicParams,
StandardLoggingMetadata,
StandardLoggingPayload,
)
# `VertexBase` is only referenced in annotations at module level. Static type
# checkers see the real class; at runtime the name is an alias for `Any`, so the
# Vertex AI module is not imported when this file is loaded.
if not TYPE_CHECKING:
    VertexBase = Any
else:
    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
# Resolved GCS logging settings for a single request:
#   bucket_name          - bucket the object is written to / read from
#   vertex_instance      - cached VertexBase used to build auth headers
#   path_service_account - service-account credentials path for that instance
GCSLoggingConfig = TypedDict(
    "GCSLoggingConfig",
    {
        "bucket_name": str,
        "vertex_instance": VertexBase,
        "path_service_account": str,
    },
)
class GCSBucketLogger(GCSBucketBase):
    """
    Logs LiteLLM ``standard_logging_object`` payloads to a Google Cloud Storage bucket.

    Objects are written and read through the GCS JSON API. The bucket name and the
    service-account credentials path default to ``self.BUCKET_NAME`` and
    ``self.path_service_account_json`` (presumably populated by
    ``GCSBucketBase.__init__``). Both can be overridden per request through
    ``kwargs["standard_callback_dynamic_params"]`` (``gcs_bucket_name`` /
    ``gcs_path_service_account``).

    This is a premium feature: the constructor and both logging hooks raise
    ``ValueError`` when the proxy's ``premium_user`` flag is not ``True``.
    """

    def __init__(self, bucket_name: Optional[str] = None) -> None:
        """
        Args:
            bucket_name: Default bucket name, forwarded to ``GCSBucketBase``.

        Raises:
            ValueError: if the proxy is not a premium user, or if no default
                service-account path is configured.
        """
        super().__init__(bucket_name=bucket_name)
        # VertexBase instances keyed by credentials path, so that
        # `_ensure_access_token_async` runs only once per credentials path.
        self.vertex_instances: Dict[str, VertexBase] = {}
        self._raise_if_not_premium_user()
        if self.path_service_account_json is None:
            raise ValueError(
                "GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
            )

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _raise_if_not_premium_user() -> None:
        """Raise ``ValueError`` unless the proxy's ``premium_user`` flag is ``True``."""
        # Imported lazily on every call (as before), so the current value is read
        # and the proxy server module is not imported at module load time.
        from litellm.proxy.proxy_server import premium_user

        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )

    @staticmethod
    def _current_date_folder() -> str:
        """Return the current local date (``datetime.now()``) formatted as ``YYYY-MM-DD``."""
        return datetime.now().strftime("%Y-%m-%d")

    @staticmethod
    def _get_standard_logging_payload(kwargs: Dict[str, Any]) -> StandardLoggingPayload:
        """
        Return ``kwargs["standard_logging_object"]``.

        Raises:
            ValueError: if the key is missing or its value is ``None``.
        """
        logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if logging_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")
        return logging_payload

    async def _get_bucket_name_and_headers(
        self, kwargs: Optional[Dict[str, Any]]
    ) -> Tuple[str, Any]:
        """Resolve the GCS config for ``kwargs``; return ``(bucket_name, request_headers)``."""
        gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
            kwargs
        )
        headers = await self.construct_request_headers(
            vertex_instance=gcs_logging_config["vertex_instance"],
            service_account_json=gcs_logging_config["path_service_account"],
        )
        return gcs_logging_config["bucket_name"], headers

    async def _upload_json_object(
        self, kwargs: Dict[str, Any], object_name: str, json_payload: str
    ) -> None:
        """
        Upload ``json_payload`` as object ``object_name`` (simple media upload).

        A non-200 response is logged as an error, not raised. ``object_name`` is
        placed into the query string verbatim.
        """
        bucket_name, headers = await self._get_bucket_name_and_headers(kwargs)
        response = await self.async_httpx_client.post(
            headers=headers,
            url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
            data=json_payload,
        )
        if response.status_code != 200:
            verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
        verbose_logger.debug("GCS Bucket response %s", response)
        verbose_logger.debug("GCS Bucket status code %s", response.status_code)
        verbose_logger.debug("GCS Bucket response.text %s", response.text)

    #### ASYNC ####

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Upload the request's standard logging payload as ``<YYYY-MM-DD>/<response_obj['id']>``.

        Raises:
            ValueError: for non-premium users (checked before any work is done).
                Every other error is logged via ``verbose_logger.exception`` and
                swallowed, so logging never breaks the request.
        """
        self._raise_if_not_premium_user()
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            logging_payload = self._get_standard_logging_payload(kwargs)
            # Group objects into one date-based "folder" per day.
            object_name = f"{self._current_date_folder()}/{response_obj['id']}"
            await self._upload_json_object(
                kwargs=kwargs,
                object_name=object_name,
                json_payload=json.dumps(logging_payload),
            )
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Upload the failed request's standard logging payload to GCS.

        The object name is ``litellm_params.metadata["gcs_log_id"]`` when that key
        is present (used verbatim). Otherwise it is
        ``<YYYY-MM-DD>/failure-<random uuid hex>``.

        Raises:
            ValueError: for non-premium users. Every other error is logged and swallowed.
        """
        self._raise_if_not_premium_user()
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            logging_payload = self._get_standard_logging_payload(kwargs)
            _litellm_params = kwargs.get("litellm_params") or {}
            metadata = _litellm_params.get("metadata") or {}
            if "gcs_log_id" in metadata:
                # Caller-chosen object name; lets the caller locate the log later.
                object_name = metadata["gcs_log_id"]
            else:
                object_name = (
                    f"{self._current_date_folder()}/failure-{uuid.uuid4().hex}"
                )
            await self._upload_json_object(
                kwargs=kwargs,
                object_name=object_name,
                json_payload=json.dumps(logging_payload),
            )
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    async def get_gcs_logging_config(
        self, kwargs: Optional[Dict[str, Any]] = None
    ) -> GCSLoggingConfig:
        """
        Resolve the bucket, credentials and Vertex instance to use for a request.

        When ``kwargs["standard_callback_dynamic_params"]`` is present, its
        ``gcs_bucket_name`` / ``gcs_path_service_account`` entries take precedence.
        Missing or falsy entries fall back to the logger's defaults.

        Args:
            kwargs: Request kwargs, or ``None`` to use the defaults.

        Returns:
            A ``GCSLoggingConfig`` with a cached (or newly created) Vertex instance.
        """
        if kwargs is None:
            kwargs = {}
        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
            kwargs.get("standard_callback_dynamic_params", None)
        )

        bucket_name: str = self.BUCKET_NAME
        path_service_account: str = self.path_service_account_json
        if standard_callback_dynamic_params is not None:
            verbose_logger.debug("Using dynamic GCS logging")
            verbose_logger.debug(
                "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
            )
            bucket_name = (
                standard_callback_dynamic_params.get("gcs_bucket_name", None)
                or self.BUCKET_NAME
            )
            path_service_account = (
                standard_callback_dynamic_params.get("gcs_path_service_account", None)
                or self.path_service_account_json
            )

        vertex_instance = await self.get_or_create_vertex_instance(
            credentials=path_service_account
        )
        return GCSLoggingConfig(
            bucket_name=bucket_name,
            vertex_instance=vertex_instance,
            path_service_account=path_service_account,
        )

    async def get_or_create_vertex_instance(self, credentials: str) -> VertexBase:
        """
        Return the cached ``VertexBase`` for ``credentials``, creating it on first use.

        A new instance has ``_ensure_access_token_async`` awaited once (with
        ``custom_llm_provider="vertex_ai"``) before it is cached in
        ``self.vertex_instances``.
        """
        from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import (
            VertexBase,
        )

        if credentials not in self.vertex_instances:
            vertex_instance = VertexBase()
            await vertex_instance._ensure_access_token_async(
                credentials=credentials,
                project_id=None,
                custom_llm_provider="vertex_ai",
            )
            self.vertex_instances[credentials] = vertex_instance
        return self.vertex_instances[credentials]

    async def download_gcs_object(self, object_name: str, **kwargs):
        """
        Download an object from GCS.
        https://cloud.google.com/storage/docs/downloading-objects#download-object-json

        Args:
            object_name: Object name, inserted into the URL path verbatim.
                NOTE(review): the JSON API expects path object names to be
                URL-encoded (``/`` as ``%2F``); callers presumably pass an
                already-encoded name — confirm.
            **kwargs: Passed to ``get_gcs_logging_config`` (e.g.
                ``standard_callback_dynamic_params``).

        Returns:
            ``response.content`` on HTTP 200, otherwise ``None``. Errors are
            logged, not raised.
        """
        try:
            bucket_name, headers = await self._get_bucket_name_and_headers(kwargs)
            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"

            # Send the GET request to download the object
            response = await self.async_httpx_client.get(url=url, headers=headers)
            if response.status_code != 200:
                verbose_logger.error(
                    "GCS object download error: %s", str(response.text)
                )
                return None

            verbose_logger.debug(
                "GCS object download response status code: %s", response.status_code
            )
            # Return the content of the downloaded object
            return response.content
        except Exception as e:
            verbose_logger.error("GCS object download error: %s", str(e))
            return None

    async def delete_gcs_object(self, object_name: str, **kwargs):
        """
        Delete an object from GCS.

        Args:
            object_name: Object name, inserted into the URL path verbatim (see
                ``download_gcs_object`` about URL-encoding).
            **kwargs: Passed to ``get_gcs_logging_config``.

        Returns:
            ``response.text`` on success (HTTP 200 or 204), otherwise ``None``.
            Errors are logged, not raised.
        """
        try:
            bucket_name, headers = await self._get_bucket_name_and_headers(kwargs)
            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"

            # Send the DELETE request to delete the object
            response = await self.async_httpx_client.delete(url=url, headers=headers)
            # GCS answers a successful delete with 204 No Content. 200 is also
            # accepted as success. Only other statuses are errors.
            if response.status_code not in (200, 204):
                verbose_logger.error(
                    "GCS object delete error: %s, status code: %s",
                    str(response.text),
                    response.status_code,
                )
                return None

            verbose_logger.debug(
                "GCS object delete response status code: %s, response: %s",
                response.status_code,
                response.text,
            )
            # Return the raw response body (empty for a 204 response).
            return response.text
        except Exception as e:
            verbose_logger.error("GCS object delete error: %s", str(e))
            return None