mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
* build(pyproject.toml): add new dev dependencies - for type checking * build: reformat files to fit black * ci: reformat to fit black * ci(test-litellm.yml): make tests run clear * build(pyproject.toml): add ruff * fix: fix ruff checks * build(mypy/): fix mypy linting errors * fix(hashicorp_secret_manager.py): fix passing cert for tls auth * build(mypy/): resolve all mypy errors * test: update test * fix: fix black formatting * build(pre-commit-config.yaml): use poetry run black * fix(proxy_server.py): fix linting error * fix: fix ruff safe representation error
326 lines
12 KiB
Python
326 lines
12 KiB
Python
import json
|
|
import os
|
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|
|
|
from litellm._logging import verbose_logger
|
|
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
|
from litellm.llms.custom_httpx.http_handler import (
|
|
get_async_httpx_client,
|
|
httpxSpecialProvider,
|
|
)
|
|
from litellm.types.integrations.gcs_bucket import *
|
|
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
|
|
|
|
if TYPE_CHECKING:
|
|
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
|
else:
|
|
VertexBase = Any
|
|
IAM_AUTH_KEY = "IAM_AUTH"
|
|
|
|
|
|
class GCSBucketBase(CustomBatchLogger):
    def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
        """Base class for GCS Bucket logging callbacks.

        Falls back to the GCS_PATH_SERVICE_ACCOUNT / GCS_BUCKET_NAME
        environment variables when explicit values are not provided.
        """
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        # Service-account JSON path comes from the environment only.
        self.path_service_account_json: Optional[str] = os.getenv(
            "GCS_PATH_SERVICE_ACCOUNT"
        )
        # Explicit argument wins over the environment variable.
        self.BUCKET_NAME: Optional[str] = bucket_name or os.getenv("GCS_BUCKET_NAME")
        # Cache of VertexBase instances keyed by credentials string
        # (see _get_in_memory_key_for_vertex_instance).
        self.vertex_instances: Dict[str, VertexBase] = {}
        super().__init__(**kwargs)
|
|
|
|
async def construct_request_headers(
    self,
    service_account_json: Optional[str],
    vertex_instance: Optional[VertexBase] = None,
) -> Dict[str, str]:
    """Build the Authorization/Content-Type headers for a GCS request.

    Uses the supplied Vertex instance (or the shared ``vertex_chat_completion``
    handler when none is given) to obtain an access token for the
    ``vertex_ai`` provider.
    """
    from litellm import vertex_chat_completion

    # Default to the shared vertex completion handler when no instance given.
    instance = (
        vertex_chat_completion if vertex_instance is None else vertex_instance
    )

    raw_token, project_id = await instance._ensure_access_token_async(
        credentials=service_account_json,
        project_id=None,
        custom_llm_provider="vertex_ai",
    )

    bearer_token, _ = instance._get_token_and_url(
        model="gcs-bucket",
        auth_header=raw_token,
        vertex_credentials=service_account_json,
        vertex_project=project_id,
        vertex_location=None,
        gemini_api_key=None,
        stream=None,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )
    verbose_logger.debug("constructed auth_header %s", bearer_token)

    return {
        "Authorization": f"Bearer {bearer_token}",  # auth_header
        "Content-Type": "application/json",
    }
|
|
|
|
def sync_construct_request_headers(self) -> Dict[str, str]:
    """Synchronous variant of ``construct_request_headers``.

    Always uses the shared ``vertex_chat_completion`` handler and the
    instance's configured service-account path.
    """
    from litellm import vertex_chat_completion

    raw_token, project_id = vertex_chat_completion._ensure_access_token(
        credentials=self.path_service_account_json,
        project_id=None,
        custom_llm_provider="vertex_ai",
    )

    bearer_token, _ = vertex_chat_completion._get_token_and_url(
        model="gcs-bucket",
        auth_header=raw_token,
        vertex_credentials=self.path_service_account_json,
        vertex_project=project_id,
        vertex_location=None,
        gemini_api_key=None,
        stream=None,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )
    verbose_logger.debug("constructed auth_header %s", bearer_token)

    return {
        "Authorization": f"Bearer {bearer_token}",  # auth_header
        "Content-Type": "application/json",
    }
|
|
|
|
def _handle_folders_in_bucket_name(
|
|
self,
|
|
bucket_name: str,
|
|
object_name: str,
|
|
) -> Tuple[str, str]:
|
|
"""
|
|
Handles when the user passes a bucket name with a folder postfix
|
|
|
|
|
|
Example:
|
|
- Bucket name: "my-bucket/my-folder/dev"
|
|
- Object name: "my-object"
|
|
- Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
|
|
|
|
"""
|
|
if "/" in bucket_name:
|
|
bucket_name, prefix = bucket_name.split("/", 1)
|
|
object_name = f"{prefix}/{object_name}"
|
|
return bucket_name, object_name
|
|
return bucket_name, object_name
|
|
|
|
async def get_gcs_logging_config(
    self, kwargs: Optional[Dict[str, Any]] = None
) -> GCSLoggingConfig:
    """
    This function is used to get the GCS logging config for the GCS Bucket Logger.
    It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
    If no dynamic parameters are provided, it uses the default values.

    Raises:
        ValueError: if no bucket name is configured (neither dynamically
            nor via GCS_BUCKET_NAME).
    """
    # NOTE: default changed from a mutable `{}` to None — mutable default
    # arguments are shared across calls; the None-check below preserves
    # the original behavior for all callers.
    if kwargs is None:
        kwargs = {}

    standard_callback_dynamic_params: Optional[
        StandardCallbackDynamicParams
    ] = kwargs.get("standard_callback_dynamic_params", None)

    bucket_name: Optional[str]
    path_service_account: Optional[str]
    if standard_callback_dynamic_params is not None:
        verbose_logger.debug("Using dynamic GCS logging")
        verbose_logger.debug(
            "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
        )
        # Dynamic (per-request) values win; fall back to instance defaults.
        bucket_name = (
            standard_callback_dynamic_params.get("gcs_bucket_name", None)
            or self.BUCKET_NAME
        )
        path_service_account = (
            standard_callback_dynamic_params.get("gcs_path_service_account", None)
            or self.path_service_account_json
        )
    else:
        # If no dynamic parameters, use the default instance configuration.
        bucket_name = self.BUCKET_NAME
        path_service_account = self.path_service_account_json

    if bucket_name is None:
        raise ValueError(
            "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
        )

    vertex_instance = await self.get_or_create_vertex_instance(
        credentials=path_service_account
    )

    return GCSLoggingConfig(
        bucket_name=bucket_name,
        vertex_instance=vertex_instance,
        path_service_account=path_service_account,
    )
|
|
|
|
async def get_or_create_vertex_instance(
    self, credentials: Optional[str]
) -> VertexBase:
    """
    This function is used to get the Vertex instance for the GCS Bucket Logger.
    It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
    """
    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

    cache_key = self._get_in_memory_key_for_vertex_instance(credentials)
    cached = self.vertex_instances.get(cache_key)
    if cached is None:
        cached = VertexBase()
        # Warm the access token so subsequent requests reuse these credentials.
        await cached._ensure_access_token_async(
            credentials=credentials,
            project_id=None,
            custom_llm_provider="vertex_ai",
        )
        self.vertex_instances[cache_key] = cached
    return cached
|
|
|
|
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
    """
    Returns key to use for caching the Vertex instance in-memory.

    When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.

    - If a credentials string is provided, it is used as the key.
    - If no credentials string is provided, "IAM_AUTH" is used as the key.
    """
    # Falsy credentials (None or empty string) fall back to the IAM key.
    if credentials:
        return credentials
    return IAM_AUTH_KEY
|
|
|
|
async def download_gcs_object(self, object_name: str, **kwargs):
    """
    Download an object from GCS.

    https://cloud.google.com/storage/docs/downloading-objects#download-object-json

    Returns the raw object bytes, or None on any failure (errors are
    logged, never raised).
    """
    try:
        gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
            kwargs=kwargs
        )
        headers = await self.construct_request_headers(
            vertex_instance=gcs_logging_config["vertex_instance"],
            service_account_json=gcs_logging_config["path_service_account"],
        )
        bucket_name, object_name = self._handle_folders_in_bucket_name(
            bucket_name=gcs_logging_config["bucket_name"],
            object_name=object_name,
        )

        # alt=media requests the object payload itself (not its metadata).
        url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"

        # Send the GET request to download the object
        response = await self.async_httpx_client.get(url=url, headers=headers)

        if response.status_code != 200:
            verbose_logger.error(
                "GCS object download error: %s", str(response.text)
            )
            return None

        verbose_logger.debug(
            "GCS object download response status code: %s", response.status_code
        )

        # Return the content of the downloaded object
        return response.content

    except Exception as e:
        verbose_logger.error("GCS object download error: %s", str(e))
        return None
|
|
|
|
async def delete_gcs_object(self, object_name: str, **kwargs):
    """
    Delete an object from GCS.

    Returns the response body text on success, or None on failure
    (errors are logged, never raised).
    """
    try:
        gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
            kwargs=kwargs
        )
        headers = await self.construct_request_headers(
            vertex_instance=gcs_logging_config["vertex_instance"],
            service_account_json=gcs_logging_config["path_service_account"],
        )
        bucket_name = gcs_logging_config["bucket_name"]
        bucket_name, object_name = self._handle_folders_in_bucket_name(
            bucket_name=bucket_name,
            object_name=object_name,
        )

        url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"

        # Send the DELETE request to delete the object
        response = await self.async_httpx_client.delete(url=url, headers=headers)

        # BUG FIX: the original condition used `or`
        # (`(status != 200) or (status != 204)`), which is always True, so
        # every delete — including successful ones — was reported as an
        # error and returned None. GCS returns 204 No Content on a
        # successful delete (200 accepted for safety).
        if response.status_code not in (200, 204):
            verbose_logger.error(
                "GCS object delete error: %s, status code: %s",
                str(response.text),
                response.status_code,
            )
            return None

        verbose_logger.debug(
            "GCS object delete response status code: %s, response: %s",
            response.status_code,
            response.text,
        )

        # Return the body of the delete response (empty for 204).
        return response.text

    except Exception as e:
        # BUG FIX: the original log message said "download" in the delete path.
        verbose_logger.error("GCS object delete error: %s", str(e))
        return None
|
|
|
|
async def _log_json_data_on_gcs(
    self,
    headers: Dict[str, str],
    bucket_name: str,
    object_name: str,
    logging_payload: Union[StandardLoggingPayload, str],
):
    """
    Helper function to make POST request to GCS Bucket in the specified bucket.

    Accepts either an already-serialized JSON string or a payload object
    (serialized here with `default=str`). Returns the parsed JSON response.
    """
    # Serialize unless the caller already handed us a JSON string.
    json_logged_payload = (
        logging_payload
        if isinstance(logging_payload, str)
        else json.dumps(logging_payload, default=str)
    )

    bucket_name, object_name = self._handle_folders_in_bucket_name(
        bucket_name=bucket_name,
        object_name=object_name,
    )

    response = await self.async_httpx_client.post(
        headers=headers,
        url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
        data=json_logged_payload,
    )

    # Log (but do not raise) on upload failure; caller gets the parsed body.
    if response.status_code != 200:
        verbose_logger.error("GCS Bucket logging error: %s", str(response.text))

    verbose_logger.debug("GCS Bucket response %s", response)
    verbose_logger.debug("GCS Bucket status code %s", response.status_code)
    verbose_logger.debug("GCS Bucket response.text %s", response.text)

    return response.json()
|