add gcs bucket base

This commit is contained in:
Ishaan Jaff 2024-08-30 10:41:39 -07:00
parent 051ac50fca
commit 7d746064ab
4 changed files with 134 additions and 102 deletions

View file

@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.gcs_bucket_base import GCSBucketBase
from litellm.litellm_core_utils.logging_utils import (
convert_litellm_response_object_to_dict,
)
@ -34,26 +35,16 @@ class GCSBucketPayload(TypedDict):
log_event_type: Optional[str]
class GCSBucketLogger(CustomLogger):
def __init__(self) -> None:
class GCSBucketLogger(GCSBucketBase):
def __init__(self, bucket_name: Optional[str] = None) -> None:
from litellm.proxy.proxy_server import premium_user
super().__init__(bucket_name=bucket_name)
if premium_user is not True:
raise ValueError(
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
)
self.async_httpx_client = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None)
self.BUCKET_NAME = os.getenv("GCS_BUCKET_NAME", None)
if self.BUCKET_NAME is None:
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
if self.path_service_account_json is None:
raise ValueError(
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
@ -158,27 +149,6 @@ class GCSBucketLogger(CustomLogger):
except Exception as e:
verbose_logger.error("GCS Bucket logging error: %s", str(e))
async def construct_request_headers(self) -> Dict[str, str]:
from litellm import vertex_chat_completion
auth_header, _ = vertex_chat_completion._get_token_and_url(
model="gcs-bucket",
vertex_credentials=self.path_service_account_json,
vertex_project=None,
vertex_location=None,
gemini_api_key=None,
stream=None,
custom_llm_provider="vertex_ai",
api_base=None,
)
verbose_logger.debug("constructed auth_header %s", auth_header)
headers = {
"Authorization": f"Bearer {auth_header}", # auth_header
"Content-Type": "application/json",
}
return headers
async def get_gcs_payload(
self, kwargs, response_obj, start_time, end_time
) -> GCSBucketPayload:
@ -225,65 +195,3 @@ class GCSBucketLogger(CustomLogger):
)
return gcs_payload
async def download_gcs_object(self, object_name):
"""
Download an object from GCS.
https://cloud.google.com/storage/docs/downloading-objects#download-object-json
"""
try:
headers = await self.construct_request_headers()
url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}?alt=media"
# Send the GET request to download the object
response = await self.async_httpx_client.get(url=url, headers=headers)
if response.status_code != 200:
verbose_logger.error(
"GCS object download error: %s", str(response.text)
)
return None
verbose_logger.debug(
"GCS object download response status code: %s", response.status_code
)
# Return the content of the downloaded object
return response.content
except Exception as e:
verbose_logger.error("GCS object download error: %s", str(e))
return None
async def delete_gcs_object(self, object_name):
"""
Delete an object from GCS.
"""
try:
headers = await self.construct_request_headers()
url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}"
# Send the DELETE request to delete the object
response = await self.async_httpx_client.delete(url=url, headers=headers)
if (response.status_code != 200) or (response.status_code != 204):
verbose_logger.error(
"GCS object delete error: %s, status code: %s",
str(response.text),
response.status_code,
)
return None
verbose_logger.debug(
"GCS object delete response status code: %s, response: %s",
response.status_code,
response.text,
)
# Return the content of the downloaded object
return response.text
except Exception as e:
verbose_logger.error("GCS object download error: %s", str(e))
return None

View file

@ -0,0 +1,115 @@
import json
import os
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, Union
import httpx
from pydantic import BaseModel, Field
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.logging_utils import (
convert_litellm_response_object_to_dict,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
class GCSBucketBase(CustomLogger):
def __init__(self, bucket_name: Optional[str] = None) -> None:
from litellm.proxy.proxy_server import premium_user
self.async_httpx_client = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None)
self.BUCKET_NAME = bucket_name or os.getenv("GCS_BUCKET_NAME", None)
if self.BUCKET_NAME is None:
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
async def construct_request_headers(self) -> Dict[str, str]:
from litellm import vertex_chat_completion
auth_header, _ = vertex_chat_completion._get_token_and_url(
model="gcs-bucket",
vertex_credentials=self.path_service_account_json,
vertex_project=None,
vertex_location=None,
gemini_api_key=None,
stream=None,
custom_llm_provider="vertex_ai",
api_base=None,
)
verbose_logger.debug("constructed auth_header %s", auth_header)
headers = {
"Authorization": f"Bearer {auth_header}", # auth_header
"Content-Type": "application/json",
}
return headers
async def download_gcs_object(self, object_name):
"""
Download an object from GCS.
https://cloud.google.com/storage/docs/downloading-objects#download-object-json
"""
try:
headers = await self.construct_request_headers()
url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}?alt=media"
# Send the GET request to download the object
response = await self.async_httpx_client.get(url=url, headers=headers)
if response.status_code != 200:
verbose_logger.error(
"GCS object download error: %s", str(response.text)
)
return None
verbose_logger.debug(
"GCS object download response status code: %s", response.status_code
)
# Return the content of the downloaded object
return response.content
except Exception as e:
verbose_logger.error("GCS object download error: %s", str(e))
return None
async def delete_gcs_object(self, object_name):
"""
Delete an object from GCS.
"""
try:
headers = await self.construct_request_headers()
url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}"
# Send the DELETE request to delete the object
response = await self.async_httpx_client.delete(url=url, headers=headers)
if (response.status_code != 200) or (response.status_code != 204):
verbose_logger.error(
"GCS object delete error: %s, status code: %s",
str(response.text),
response.status_code,
)
return None
verbose_logger.debug(
"GCS object delete response status code: %s, response: %s",
response.status_code,
response.text,
)
# Return the content of the downloaded object
return response.text
except Exception as e:
verbose_logger.error("GCS object download error: %s", str(e))
return None

View file

@ -53,9 +53,9 @@ def get_file_contents_from_s3(bucket_name, object_key):
async def get_config_file_contents_from_gcs(bucket_name, object_key):
try:
from litellm.integrations.gcs_bucket import GCSBucketLogger
from litellm.integrations.gcs_bucket_base import GCSBucketBase
gcs_bucket = GCSBucketLogger(
gcs_bucket = GCSBucketBase(
bucket_name=bucket_name,
)
file_contents = await gcs_bucket.download_gcs_object(object_key)

View file

@ -163,7 +163,10 @@ from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
check_file_size_under_limit,
)
from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
from litellm.proxy.common_utils.load_config_utils import (
get_config_file_contents_from_gcs,
get_file_contents_from_s3,
)
from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment,
)
@ -1493,9 +1496,15 @@ class ProxyConfig:
if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
bucket_type = os.environ.get("LITELLM_CONFIG_BUCKET_TYPE")
verbose_proxy_logger.debug(
"bucket_name: %s, object_key: %s", bucket_name, object_key
)
if bucket_type == "gcs":
config = await get_config_file_contents_from_gcs(
bucket_name=bucket_name, object_key=object_key
)
else:
config = get_file_contents_from_s3(
bucket_name=bucket_name, object_key=object_key
)