diff --git a/litellm/integrations/gcs_bucket.py b/litellm/integrations/gcs_bucket.py
index be7f8e39c2..22802797fa 100644
--- a/litellm/integrations/gcs_bucket.py
+++ b/litellm/integrations/gcs_bucket.py
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
 import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.gcs_bucket_base import GCSBucketBase
 from litellm.litellm_core_utils.logging_utils import (
     convert_litellm_response_object_to_dict,
 )
@@ -34,26 +35,16 @@ class GCSBucketPayload(TypedDict):
     log_event_type: Optional[str]
 
 
-class GCSBucketLogger(CustomLogger):
-    def __init__(self) -> None:
+class GCSBucketLogger(GCSBucketBase):
+    def __init__(self, bucket_name: Optional[str] = None) -> None:
         from litellm.proxy.proxy_server import premium_user
 
+        super().__init__(bucket_name=bucket_name)
         if premium_user is not True:
             raise ValueError(
                 f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
             )
 
-        self.async_httpx_client = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
-        self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None)
-        self.BUCKET_NAME = os.getenv("GCS_BUCKET_NAME", None)
-
-        if self.BUCKET_NAME is None:
-            raise ValueError(
-                "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
-            )
-
         if self.path_service_account_json is None:
             raise ValueError(
                 "GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
@@ -158,27 +149,6 @@ class GCSBucketLogger(CustomLogger):
         except Exception as e:
             verbose_logger.error("GCS Bucket logging error: %s", str(e))
 
-    async def construct_request_headers(self) -> Dict[str, str]:
-        from litellm import vertex_chat_completion
-
-        auth_header, _ = vertex_chat_completion._get_token_and_url(
-            model="gcs-bucket",
-            vertex_credentials=self.path_service_account_json,
-            vertex_project=None,
-            vertex_location=None,
-            gemini_api_key=None,
-            stream=None,
-            custom_llm_provider="vertex_ai",
-            api_base=None,
-        )
-        verbose_logger.debug("constructed auth_header %s", auth_header)
-        headers = {
-            "Authorization": f"Bearer {auth_header}",  # auth_header
-            "Content-Type": "application/json",
-        }
-
-        return headers
-
     async def get_gcs_payload(
         self, kwargs, response_obj, start_time, end_time
     ) -> GCSBucketPayload:
@@ -225,65 +195,3 @@ class GCSBucketLogger(CustomLogger):
         )
 
         return gcs_payload
-
-    async def download_gcs_object(self, object_name):
-        """
-        Download an object from GCS.
-
-        https://cloud.google.com/storage/docs/downloading-objects#download-object-json
-        """
-        try:
-            headers = await self.construct_request_headers()
-            url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}?alt=media"
-
-            # Send the GET request to download the object
-            response = await self.async_httpx_client.get(url=url, headers=headers)
-
-            if response.status_code != 200:
-                verbose_logger.error(
-                    "GCS object download error: %s", str(response.text)
-                )
-                return None
-
-            verbose_logger.debug(
-                "GCS object download response status code: %s", response.status_code
-            )
-
-            # Return the content of the downloaded object
-            return response.content
-
-        except Exception as e:
-            verbose_logger.error("GCS object download error: %s", str(e))
-            return None
-
-    async def delete_gcs_object(self, object_name):
-        """
-        Delete an object from GCS.
-        """
-        try:
-            headers = await self.construct_request_headers()
-            url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}"
-
-            # Send the DELETE request to delete the object
-            response = await self.async_httpx_client.delete(url=url, headers=headers)
-
-            if (response.status_code != 200) or (response.status_code != 204):
-                verbose_logger.error(
-                    "GCS object delete error: %s, status code: %s",
-                    str(response.text),
-                    response.status_code,
-                )
-                return None
-
-            verbose_logger.debug(
-                "GCS object delete response status code: %s, response: %s",
-                response.status_code,
-                response.text,
-            )
-
-            # Return the content of the downloaded object
-            return response.text
-
-        except Exception as e:
-            verbose_logger.error("GCS object download error: %s", str(e))
-            return None
diff --git a/litellm/integrations/gcs_bucket_base.py b/litellm/integrations/gcs_bucket_base.py
new file mode 100644
index 0000000000..2f34205ce3
--- /dev/null
+++ b/litellm/integrations/gcs_bucket_base.py
@@ -0,0 +1,115 @@
+import json
+import os
+import uuid
+from datetime import datetime
+from typing import Any, Dict, List, Optional, TypedDict, Union
+
+import httpx
+from pydantic import BaseModel, Field
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.logging_utils import (
+    convert_litellm_response_object_to_dict,
+)
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+
+class GCSBucketBase(CustomLogger):
+    def __init__(self, bucket_name: Optional[str] = None) -> None:
+        from litellm.proxy.proxy_server import premium_user
+
+        self.async_httpx_client = AsyncHTTPHandler(
+            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        )
+        self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None)
+        self.BUCKET_NAME = bucket_name or os.getenv("GCS_BUCKET_NAME", None)
+
+        if self.BUCKET_NAME is None:
+            raise ValueError(
+                "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
+            )
+
+    async def construct_request_headers(self) -> Dict[str, str]:
+        from litellm import vertex_chat_completion
+
+        auth_header, _ = vertex_chat_completion._get_token_and_url(
+            model="gcs-bucket",
+            vertex_credentials=self.path_service_account_json,
+            vertex_project=None,
+            vertex_location=None,
+            gemini_api_key=None,
+            stream=None,
+            custom_llm_provider="vertex_ai",
+            api_base=None,
+        )
+        verbose_logger.debug("constructed auth_header %s", auth_header)
+        headers = {
+            "Authorization": f"Bearer {auth_header}",  # auth_header
+            "Content-Type": "application/json",
+        }
+
+        return headers
+
+    async def download_gcs_object(self, object_name):
+        """
+        Download an object from GCS.
+
+        https://cloud.google.com/storage/docs/downloading-objects#download-object-json
+        """
+        try:
+            headers = await self.construct_request_headers()
+            url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}?alt=media"
+
+            # Send the GET request to download the object
+            response = await self.async_httpx_client.get(url=url, headers=headers)
+
+            if response.status_code != 200:
+                verbose_logger.error(
+                    "GCS object download error: %s", str(response.text)
+                )
+                return None
+
+            verbose_logger.debug(
+                "GCS object download response status code: %s", response.status_code
+            )
+
+            # Return the content of the downloaded object
+            return response.content
+
+        except Exception as e:
+            verbose_logger.error("GCS object download error: %s", str(e))
+            return None
+
+    async def delete_gcs_object(self, object_name):
+        """
+        Delete an object from GCS.
+        """
+        try:
+            headers = await self.construct_request_headers()
+            url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}"
+
+            # Send the DELETE request to delete the object
+            response = await self.async_httpx_client.delete(url=url, headers=headers)
+
+            if response.status_code not in (200, 204):
+                verbose_logger.error(
+                    "GCS object delete error: %s, status code: %s",
+                    str(response.text),
+                    response.status_code,
+                )
+                return None
+
+            verbose_logger.debug(
+                "GCS object delete response status code: %s, response: %s",
+                response.status_code,
+                response.text,
+            )
+
+            # Return the response body from the delete operation
+            return response.text
+
+        except Exception as e:
+            verbose_logger.error("GCS object delete error: %s", str(e))
+            return None
diff --git a/litellm/proxy/common_utils/load_config_utils.py b/litellm/proxy/common_utils/load_config_utils.py
index 598ac835ea..93328b15e4 100644
--- a/litellm/proxy/common_utils/load_config_utils.py
+++ b/litellm/proxy/common_utils/load_config_utils.py
@@ -53,9 +53,9 @@ def get_file_contents_from_s3(bucket_name, object_key):
 
 async def get_config_file_contents_from_gcs(bucket_name, object_key):
     try:
-        from litellm.integrations.gcs_bucket import GCSBucketLogger
+        from litellm.integrations.gcs_bucket_base import GCSBucketBase
 
-        gcs_bucket = GCSBucketLogger(
+        gcs_bucket = GCSBucketBase(
             bucket_name=bucket_name,
         )
         file_contents = await gcs_bucket.download_gcs_object(object_key)
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index e0dbcebfc9..82a019d946 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -163,7 +163,10 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
 )
-from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
+from litellm.proxy.common_utils.load_config_utils import (
+    get_config_file_contents_from_gcs,
+    get_file_contents_from_s3,
+)
 from litellm.proxy.common_utils.openai_endpoint_utils import (
     remove_sensitive_info_from_deployment,
 )
@@ -1493,12 +1496,18 @@ class ProxyConfig:
         if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
             bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
             object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
+            bucket_type = os.environ.get("LITELLM_CONFIG_BUCKET_TYPE")
             verbose_proxy_logger.debug(
                 "bucket_name: %s, object_key: %s", bucket_name, object_key
             )
-            config = get_file_contents_from_s3(
-                bucket_name=bucket_name, object_key=object_key
-            )
+            if bucket_type == "gcs":
+                config = await get_config_file_contents_from_gcs(
+                    bucket_name=bucket_name, object_key=object_key
+                )
+            else:
+                config = get_file_contents_from_s3(
+                    bucket_name=bucket_name, object_key=object_key
+                )
         else:
             # default to file
             config = await self.get_config(config_file_path=config_file_path)
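
Usage note (not part of the patch above): the sketch below exercises the new GCSBucketBase.download_gcs_object helper directly, which is the same call get_config_file_contents_from_gcs makes when the proxy pulls its config from GCS. It assumes GCS_PATH_SERVICE_ACCOUNT points at a service-account JSON with read access to the bucket; the bucket name, object key, and file path are placeholders, not values taken from the diff.

import asyncio
import os

from litellm.integrations.gcs_bucket_base import GCSBucketBase

# Placeholder path; construct_request_headers() uses this service-account JSON
# (via the vertex_ai credential path) to mint the Bearer token sent to GCS.
os.environ["GCS_PATH_SERVICE_ACCOUNT"] = "/path/to/service_account.json"


async def main() -> None:
    # bucket_name may also come from the GCS_BUCKET_NAME env var if omitted.
    gcs = GCSBucketBase(bucket_name="my-config-bucket")  # placeholder bucket
    # Returns the object bytes on success, None on any error.
    contents = await gcs.download_gcs_object("proxy_config.yaml")
    if contents is not None:
        print(contents.decode("utf-8"))


asyncio.run(main())

To have the proxy load its own config this way, the new branch in ProxyConfig keys off LITELLM_CONFIG_BUCKET_TYPE=gcs alongside the existing LITELLM_CONFIG_BUCKET_NAME and LITELLM_CONFIG_BUCKET_OBJECT_KEY variables; any other value falls back to the S3 code path.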