# litellm-mirror/litellm/integrations/azure_storage/azure_storage.py

import asyncio
import json
import os
import uuid
from datetime import datetime, timedelta
from typing import List, Optional

from litellm._logging import verbose_logger
from litellm.constants import AZURE_STORAGE_MSFT_VERSION
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.azure.common_utils import get_azure_ad_token_from_entrata_id
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload


class AzureBlobStorageLogger(CustomBatchLogger):
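    """
    Batch logger that ships LiteLLM StandardLoggingPayload objects to an
    Azure Data Lake Storage Gen2 file system, authenticating either with an
    Azure Storage account key (via the Azure SDK) or with a Microsoft Entra ID
    service principal (via the Azure Storage REST API).
    """
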
def __init__(
self,
**kwargs,
):
try:
verbose_logger.debug(
"AzureBlobStorageLogger: in init azure blob storage logger"
)
# Env Variables used for Azure Storage Authentication
self.tenant_id = os.getenv("AZURE_STORAGE_TENANT_ID")
self.client_id = os.getenv("AZURE_STORAGE_CLIENT_ID")
self.client_secret = os.getenv("AZURE_STORAGE_CLIENT_SECRET")
self.azure_storage_account_key: Optional[str] = os.getenv(
"AZURE_STORAGE_ACCOUNT_KEY"
)
# Required Env Variables for Azure Storage
_azure_storage_account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
if not _azure_storage_account_name:
raise ValueError(
"Missing required environment variable: AZURE_STORAGE_ACCOUNT_NAME"
)
self.azure_storage_account_name: str = _azure_storage_account_name
_azure_storage_file_system = os.getenv("AZURE_STORAGE_FILE_SYSTEM")
if not _azure_storage_file_system:
raise ValueError(
"Missing required environment variable: AZURE_STORAGE_FILE_SYSTEM"
)
self.azure_storage_file_system: str = _azure_storage_file_system
            # Internal variables used for token-based authentication
            self.azure_auth_token: Optional[str] = (
                None  # the Azure AD token to use for Azure Storage API requests
            )
            self.token_expiry: Optional[datetime] = (
                None  # the expiry time of the current Azure AD token
            )
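            # Start the background flush loop provided by CustomBatchLogger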
asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
self.log_queue: List[StandardLoggingPayload] = []
super().__init__(**kwargs, flush_lock=self.flush_lock)
except Exception as e:
verbose_logger.exception(
f"AzureBlobStorageLogger: Got exception on init AzureBlobStorageLogger client {str(e)}"
)
            raise e

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Async Log success events to Azure Blob Storage
Raises:
Raises a NON Blocking verbose_logger.exception if an error occurs
"""
try:
self._premium_user_check()
verbose_logger.debug(
"AzureBlobStorageLogger: Logging - Enters logging function for model %s",
kwargs,
)
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
if standard_logging_payload is None:
raise ValueError("standard_logging_payload is not set")
self.log_queue.append(standard_logging_payload)
except Exception as e:
verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}")
            pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
"""
Async Log failure events to Azure Blob Storage
Raises:
Raises a NON Blocking verbose_logger.exception if an error occurs
"""
try:
self._premium_user_check()
verbose_logger.debug(
"AzureBlobStorageLogger: Logging - Enters logging function for model %s",
kwargs,
)
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
if standard_logging_payload is None:
raise ValueError("standard_logging_payload is not set")
self.log_queue.append(standard_logging_payload)
except Exception as e:
verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}")
            pass

    async def async_send_batch(self):
        """
        Sends the in-memory log queue to Azure Blob Storage.

        Errors are logged with a non-blocking verbose_logger.exception and are
        not raised.
        """
        try:
            if not self.log_queue:
                verbose_logger.debug(
                    "AzureBlobStorageLogger: log_queue is empty, nothing to send"
                )
                return
verbose_logger.debug(
"AzureBlobStorageLogger - about to flush %s events",
len(self.log_queue),
)
for payload in self.log_queue:
await self.async_upload_payload_to_azure_blob_storage(payload=payload)
except Exception as e:
verbose_logger.exception(
f"AzureBlobStorageLogger Error sending batch API - {str(e)}"
            )

    async def async_upload_payload_to_azure_blob_storage(
self, payload: StandardLoggingPayload
):
"""
Uploads the payload to Azure Blob Storage using a 3-step process:
1. Create file resource
2. Append data
3. Flush the data
"""
try:
if self.azure_storage_account_key:
await self.upload_to_azure_data_lake_with_azure_account_key(
payload=payload
)
else:
# Get a valid token instead of always requesting a new one
await self.set_valid_azure_ad_token()
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
json_payload = (
json.dumps(payload) + "\n"
) # Add newline for each log entry
payload_bytes = json_payload.encode("utf-8")
filename = f"{payload.get('id') or str(uuid.uuid4())}.json"
base_url = f"https://{self.azure_storage_account_name}.dfs.core.windows.net/{self.azure_storage_file_system}/{filename}"
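                # Targets the ADLS Gen2 (dfs) endpoint; each log entry becomes
                # one JSON file at the root of the configured file system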
# Execute the 3-step upload process
await self._create_file(async_client, base_url)
await self._append_data(async_client, base_url, json_payload)
await self._flush_data(async_client, base_url, len(payload_bytes))
verbose_logger.debug(
f"Successfully uploaded log to Azure Blob Storage: {filename}"
)
except Exception as e:
verbose_logger.exception(f"Error uploading to Azure Blob Storage: {str(e)}")
            raise e

    async def _create_file(self, client: AsyncHTTPHandler, base_url: str):
"""Helper method to create the file resource"""
try:
verbose_logger.debug(f"Creating file resource at: {base_url}")
headers = {
"x-ms-version": AZURE_STORAGE_MSFT_VERSION,
"Content-Length": "0",
"Authorization": f"Bearer {self.azure_auth_token}",
}
response = await client.put(f"{base_url}?resource=file", headers=headers)
response.raise_for_status()
verbose_logger.debug("Successfully created file resource")
except Exception as e:
verbose_logger.exception(f"Error creating file resource: {str(e)}")
            raise

    async def _append_data(
self, client: AsyncHTTPHandler, base_url: str, json_payload: str
):
"""Helper method to append data to the file"""
try:
verbose_logger.debug(f"Appending data to file: {base_url}")
headers = {
"x-ms-version": AZURE_STORAGE_MSFT_VERSION,
"Content-Type": "application/json",
"Authorization": f"Bearer {self.azure_auth_token}",
}
response = await client.patch(
f"{base_url}?action=append&position=0",
headers=headers,
data=json_payload,
)
response.raise_for_status()
verbose_logger.debug("Successfully appended data")
except Exception as e:
verbose_logger.exception(f"Error appending data: {str(e)}")
            raise

    async def _flush_data(self, client: AsyncHTTPHandler, base_url: str, position: int):
"""Helper method to flush the data"""
try:
verbose_logger.debug(f"Flushing data at position {position}")
headers = {
"x-ms-version": AZURE_STORAGE_MSFT_VERSION,
"Content-Length": "0",
"Authorization": f"Bearer {self.azure_auth_token}",
}
response = await client.patch(
f"{base_url}?action=flush&position={position}", headers=headers
)
response.raise_for_status()
verbose_logger.debug("Successfully flushed data")
except Exception as e:
verbose_logger.exception(f"Error flushing data: {str(e)}")
            raise

    ####### Helper methods for managing authentication to Azure Storage ######
    ##########################################################################
    async def set_valid_azure_ad_token(self):
"""
Wrapper to set self.azure_auth_token to a valid Azure AD token, refreshing if necessary
Refreshes the token when:
- Token is expired
- Token is not set
"""
# Check if token needs refresh
if self._azure_ad_token_is_expired() or self.azure_auth_token is None:
verbose_logger.debug("Azure AD token needs refresh")
self.azure_auth_token = self.get_azure_ad_token_from_azure_storage(
tenant_id=self.tenant_id,
client_id=self.client_id,
client_secret=self.client_secret,
)
# Token typically expires in 1 hour
self.token_expiry = datetime.now() + timedelta(hours=1)
verbose_logger.debug(f"New token will expire at {self.token_expiry}")
def get_azure_ad_token_from_azure_storage(
self,
tenant_id: Optional[str],
client_id: Optional[str],
client_secret: Optional[str],
) -> str:
"""
Gets Azure AD token to use for Azure Storage API requests
"""
verbose_logger.debug("Getting Azure AD Token from Azure Storage")
        verbose_logger.debug(
            "tenant_id %s, client_id %s, client_secret set: %s",
            tenant_id,
            client_id,
            client_secret is not None,
        )
if tenant_id is None:
raise ValueError(
"Missing required environment variable: AZURE_STORAGE_TENANT_ID"
)
if client_id is None:
raise ValueError(
"Missing required environment variable: AZURE_STORAGE_CLIENT_ID"
)
if client_secret is None:
raise ValueError(
"Missing required environment variable: AZURE_STORAGE_CLIENT_SECRET"
)
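        # Client-credentials flow against Microsoft Entra ID; the
        # storage.azure.com scope covers the Azure Storage REST APIs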
token_provider = get_azure_ad_token_from_entrata_id(
tenant_id=tenant_id,
client_id=client_id,
client_secret=client_secret,
scope="https://storage.azure.com/.default",
)
token = token_provider()
verbose_logger.debug("azure auth token %s", token)
return token
def _azure_ad_token_is_expired(self):
"""
Returns True if Azure AD token is expired, False otherwise
"""
if self.azure_auth_token and self.token_expiry:
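            # Treat the token as expired 5 minutes early so it is never used
            # right at the edge of its lifetime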
if datetime.now() + timedelta(minutes=5) >= self.token_expiry:
verbose_logger.debug("Azure AD token is expired. Requesting new token")
return True
        return False

    def _premium_user_check(self):
"""
Checks if the user is a premium user, raises an error if not
"""
        from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

        if premium_user is not True:
raise ValueError(
f"AzureBlobStorageLogger is only available for premium users. {CommonProxyErrors.not_premium_user}"
            )

    async def upload_to_azure_data_lake_with_azure_account_key(
        self, payload: StandardLoggingPayload
    ):
        """
        Uploads the payload to Azure Data Lake using the Azure SDK.

        Used when an Azure Storage account key is set, since account keys
        cannot be used as bearer tokens against the Azure Storage REST API.
        """
        from azure.storage.filedatalake.aio import DataLakeServiceClient

        # Create an async service client
service_client = DataLakeServiceClient(
account_url=f"https://{self.azure_storage_account_name}.dfs.core.windows.net",
credential=self.azure_storage_account_key,
)
# Get file system client
file_system_client = service_client.get_file_system_client(
file_system=self.azure_storage_file_system
)
try:
            # Create a directory named with today's date
            today = datetime.now().strftime("%Y-%m-%d")
directory_client = file_system_client.get_directory_client(today)
# check if the directory exists
if not await directory_client.exists():
await directory_client.create_directory()
verbose_logger.debug(f"Created directory: {today}")
# Create a file client
file_name = f"{payload.get('id') or str(uuid.uuid4())}.json"
file_client = directory_client.get_file_client(file_name)
# Create the file
await file_client.create_file()
# Content to append
content = json.dumps(payload).encode("utf-8")
# Append content to the file
await file_client.append_data(data=content, offset=0, length=len(content))
# Flush the content to finalize the file
await file_client.flush_data(position=len(content), offset=0)
verbose_logger.debug(
f"Successfully uploaded and wrote to {today}/{file_name}"
)
        except Exception as e:
            verbose_logger.exception(f"Error uploading to Azure Data Lake: {str(e)}")
        finally:
            # Close the async client to release its underlying HTTP session
            await service_client.close()