forked from phoenix/litellm-mirror
add gcs bucket base
This commit is contained in:
parent
051ac50fca
commit
7d746064ab
4 changed files with 134 additions and 102 deletions
|
@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.integrations.gcs_bucket_base import GCSBucketBase
|
||||||
from litellm.litellm_core_utils.logging_utils import (
|
from litellm.litellm_core_utils.logging_utils import (
|
||||||
convert_litellm_response_object_to_dict,
|
convert_litellm_response_object_to_dict,
|
||||||
)
|
)
|
||||||
|
@ -34,26 +35,16 @@ class GCSBucketPayload(TypedDict):
|
||||||
log_event_type: Optional[str]
|
log_event_type: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class GCSBucketLogger(CustomLogger):
|
class GCSBucketLogger(GCSBucketBase):
|
||||||
def __init__(self) -> None:
|
def __init__(self, bucket_name: Optional[str] = None) -> None:
|
||||||
from litellm.proxy.proxy_server import premium_user
|
from litellm.proxy.proxy_server import premium_user
|
||||||
|
|
||||||
|
super().__init__(bucket_name=bucket_name)
|
||||||
if premium_user is not True:
|
if premium_user is not True:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
|
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.async_httpx_client = AsyncHTTPHandler(
|
|
||||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
|
||||||
)
|
|
||||||
self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None)
|
|
||||||
self.BUCKET_NAME = os.getenv("GCS_BUCKET_NAME", None)
|
|
||||||
|
|
||||||
if self.BUCKET_NAME is None:
|
|
||||||
raise ValueError(
|
|
||||||
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.path_service_account_json is None:
|
if self.path_service_account_json is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
|
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
|
||||||
|
@ -158,27 +149,6 @@ class GCSBucketLogger(CustomLogger):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_logger.error("GCS Bucket logging error: %s", str(e))
|
verbose_logger.error("GCS Bucket logging error: %s", str(e))
|
||||||
|
|
||||||
async def construct_request_headers(self) -> Dict[str, str]:
|
|
||||||
from litellm import vertex_chat_completion
|
|
||||||
|
|
||||||
auth_header, _ = vertex_chat_completion._get_token_and_url(
|
|
||||||
model="gcs-bucket",
|
|
||||||
vertex_credentials=self.path_service_account_json,
|
|
||||||
vertex_project=None,
|
|
||||||
vertex_location=None,
|
|
||||||
gemini_api_key=None,
|
|
||||||
stream=None,
|
|
||||||
custom_llm_provider="vertex_ai",
|
|
||||||
api_base=None,
|
|
||||||
)
|
|
||||||
verbose_logger.debug("constructed auth_header %s", auth_header)
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {auth_header}", # auth_header
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
return headers
|
|
||||||
|
|
||||||
async def get_gcs_payload(
|
async def get_gcs_payload(
|
||||||
self, kwargs, response_obj, start_time, end_time
|
self, kwargs, response_obj, start_time, end_time
|
||||||
) -> GCSBucketPayload:
|
) -> GCSBucketPayload:
|
||||||
|
@ -225,65 +195,3 @@ class GCSBucketLogger(CustomLogger):
|
||||||
)
|
)
|
||||||
|
|
||||||
return gcs_payload
|
return gcs_payload
|
||||||
|
|
||||||
async def download_gcs_object(self, object_name):
|
|
||||||
"""
|
|
||||||
Download an object from GCS.
|
|
||||||
|
|
||||||
https://cloud.google.com/storage/docs/downloading-objects#download-object-json
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
headers = await self.construct_request_headers()
|
|
||||||
url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}?alt=media"
|
|
||||||
|
|
||||||
# Send the GET request to download the object
|
|
||||||
response = await self.async_httpx_client.get(url=url, headers=headers)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
verbose_logger.error(
|
|
||||||
"GCS object download error: %s", str(response.text)
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
verbose_logger.debug(
|
|
||||||
"GCS object download response status code: %s", response.status_code
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return the content of the downloaded object
|
|
||||||
return response.content
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
verbose_logger.error("GCS object download error: %s", str(e))
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def delete_gcs_object(self, object_name):
|
|
||||||
"""
|
|
||||||
Delete an object from GCS.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
headers = await self.construct_request_headers()
|
|
||||||
url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}"
|
|
||||||
|
|
||||||
# Send the DELETE request to delete the object
|
|
||||||
response = await self.async_httpx_client.delete(url=url, headers=headers)
|
|
||||||
|
|
||||||
if (response.status_code != 200) or (response.status_code != 204):
|
|
||||||
verbose_logger.error(
|
|
||||||
"GCS object delete error: %s, status code: %s",
|
|
||||||
str(response.text),
|
|
||||||
response.status_code,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
verbose_logger.debug(
|
|
||||||
"GCS object delete response status code: %s, response: %s",
|
|
||||||
response.status_code,
|
|
||||||
response.text,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return the content of the downloaded object
|
|
||||||
return response.text
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
verbose_logger.error("GCS object download error: %s", str(e))
|
|
||||||
return None
|
|
||||||
|
|
115
litellm/integrations/gcs_bucket_base.py
Normal file
115
litellm/integrations/gcs_bucket_base.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.litellm_core_utils.logging_utils import (
|
||||||
|
convert_litellm_response_object_to_dict,
|
||||||
|
)
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
|
||||||
|
|
||||||
|
class GCSBucketBase(CustomLogger):
|
||||||
|
def __init__(self, bucket_name: Optional[str] = None) -> None:
|
||||||
|
from litellm.proxy.proxy_server import premium_user
|
||||||
|
|
||||||
|
self.async_httpx_client = AsyncHTTPHandler(
|
||||||
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
|
)
|
||||||
|
self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None)
|
||||||
|
self.BUCKET_NAME = bucket_name or os.getenv("GCS_BUCKET_NAME", None)
|
||||||
|
|
||||||
|
if self.BUCKET_NAME is None:
|
||||||
|
raise ValueError(
|
||||||
|
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def construct_request_headers(self) -> Dict[str, str]:
|
||||||
|
from litellm import vertex_chat_completion
|
||||||
|
|
||||||
|
auth_header, _ = vertex_chat_completion._get_token_and_url(
|
||||||
|
model="gcs-bucket",
|
||||||
|
vertex_credentials=self.path_service_account_json,
|
||||||
|
vertex_project=None,
|
||||||
|
vertex_location=None,
|
||||||
|
gemini_api_key=None,
|
||||||
|
stream=None,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
|
api_base=None,
|
||||||
|
)
|
||||||
|
verbose_logger.debug("constructed auth_header %s", auth_header)
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {auth_header}", # auth_header
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
async def download_gcs_object(self, object_name):
|
||||||
|
"""
|
||||||
|
Download an object from GCS.
|
||||||
|
|
||||||
|
https://cloud.google.com/storage/docs/downloading-objects#download-object-json
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
headers = await self.construct_request_headers()
|
||||||
|
url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}?alt=media"
|
||||||
|
|
||||||
|
# Send the GET request to download the object
|
||||||
|
response = await self.async_httpx_client.get(url=url, headers=headers)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
verbose_logger.error(
|
||||||
|
"GCS object download error: %s", str(response.text)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
verbose_logger.debug(
|
||||||
|
"GCS object download response status code: %s", response.status_code
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the content of the downloaded object
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.error("GCS object download error: %s", str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def delete_gcs_object(self, object_name):
|
||||||
|
"""
|
||||||
|
Delete an object from GCS.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
headers = await self.construct_request_headers()
|
||||||
|
url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}"
|
||||||
|
|
||||||
|
# Send the DELETE request to delete the object
|
||||||
|
response = await self.async_httpx_client.delete(url=url, headers=headers)
|
||||||
|
|
||||||
|
if (response.status_code != 200) or (response.status_code != 204):
|
||||||
|
verbose_logger.error(
|
||||||
|
"GCS object delete error: %s, status code: %s",
|
||||||
|
str(response.text),
|
||||||
|
response.status_code,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
verbose_logger.debug(
|
||||||
|
"GCS object delete response status code: %s, response: %s",
|
||||||
|
response.status_code,
|
||||||
|
response.text,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the content of the downloaded object
|
||||||
|
return response.text
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.error("GCS object download error: %s", str(e))
|
||||||
|
return None
|
|
@ -53,9 +53,9 @@ def get_file_contents_from_s3(bucket_name, object_key):
|
||||||
|
|
||||||
async def get_config_file_contents_from_gcs(bucket_name, object_key):
|
async def get_config_file_contents_from_gcs(bucket_name, object_key):
|
||||||
try:
|
try:
|
||||||
from litellm.integrations.gcs_bucket import GCSBucketLogger
|
from litellm.integrations.gcs_bucket_base import GCSBucketBase
|
||||||
|
|
||||||
gcs_bucket = GCSBucketLogger(
|
gcs_bucket = GCSBucketBase(
|
||||||
bucket_name=bucket_name,
|
bucket_name=bucket_name,
|
||||||
)
|
)
|
||||||
file_contents = await gcs_bucket.download_gcs_object(object_key)
|
file_contents = await gcs_bucket.download_gcs_object(object_key)
|
||||||
|
|
|
@ -163,7 +163,10 @@ from litellm.proxy.common_utils.http_parsing_utils import (
|
||||||
_read_request_body,
|
_read_request_body,
|
||||||
check_file_size_under_limit,
|
check_file_size_under_limit,
|
||||||
)
|
)
|
||||||
from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
|
from litellm.proxy.common_utils.load_config_utils import (
|
||||||
|
get_config_file_contents_from_gcs,
|
||||||
|
get_file_contents_from_s3,
|
||||||
|
)
|
||||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||||
remove_sensitive_info_from_deployment,
|
remove_sensitive_info_from_deployment,
|
||||||
)
|
)
|
||||||
|
@ -1493,9 +1496,15 @@ class ProxyConfig:
|
||||||
if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
|
if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
|
||||||
bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
|
bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
|
||||||
object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
|
object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
|
||||||
|
bucket_type = os.environ.get("LITELLM_CONFIG_BUCKET_TYPE")
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"bucket_name: %s, object_key: %s", bucket_name, object_key
|
"bucket_name: %s, object_key: %s", bucket_name, object_key
|
||||||
)
|
)
|
||||||
|
if bucket_type == "gcs":
|
||||||
|
config = await get_config_file_contents_from_gcs(
|
||||||
|
bucket_name=bucket_name, object_key=object_key
|
||||||
|
)
|
||||||
|
else:
|
||||||
config = get_file_contents_from_s3(
|
config = get_file_contents_from_s3(
|
||||||
bucket_name=bucket_name, object_key=object_key
|
bucket_name=bucket_name, object_key=object_key
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue