diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md
index 493687192f..00f6981a52 100644
--- a/docs/my-website/docs/proxy/config_settings.md
+++ b/docs/my-website/docs/proxy/config_settings.md
@@ -317,6 +317,11 @@ router_settings:
 | AZURE_CLIENT_SECRET | Client secret for Azure services
 | AZURE_FEDERATED_TOKEN_FILE | File path to Azure federated token
 | AZURE_KEY_VAULT_URI | URI for Azure Key Vault
+| AZURE_STORAGE_ACCOUNT_NAME | Name of the Azure Storage Account to use for logging to Azure Blob Storage
+| AZURE_STORAGE_FILE_SYSTEM | Name of the Azure Storage File System to use for logging to Azure Blob Storage. (Typically the Container name)
+| AZURE_STORAGE_TENANT_ID | The Application Tenant ID to use for Authentication to Azure Blob Storage logging
+| AZURE_STORAGE_CLIENT_ID | The Application Client ID to use for Authentication to Azure Blob Storage logging
+| AZURE_STORAGE_CLIENT_SECRET | The Application Client Secret to use for Authentication to Azure Blob Storage logging
 | AZURE_TENANT_ID | Tenant ID for Azure Active Directory
 | BERRISPEND_ACCOUNT_ID | Account ID for BerriSpend service
 | BRAINTRUST_API_KEY | API key for Braintrust integration
diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md
index 4b9184e50c..27376e2cd4 100644
--- a/docs/my-website/docs/proxy/logging.md
+++ b/docs/my-website/docs/proxy/logging.md
@@ -4,7 +4,7 @@ Log Proxy input, output, and exceptions using:
 
 - Langfuse
 - OpenTelemetry
-- GCS and s3 Buckets
+- GCS, s3, Azure (Blob) Buckets
 - Custom Callbacks
 - Langsmith
 - DataDog
@@ -795,7 +795,7 @@ Log LLM Logs to [Google Cloud Storage Buckets](https://cloud.google.com/storage?
 ```yaml
 model_list:
 - litellm_params:
-    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
+    api_base: https://exampleopenaiendpoint-production.up.railway.app/
     api_key: my-fake-key
     model: openai/my-fake-model
   model_name: fake-openai-endpoint
@@ -841,7 +841,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
 
 #### Fields Logged on GCS Buckets
 
-[**The standard logging object is logged on GCS Bucket**](../proxy/logging)
+[**The standard logging object is logged on GCS Bucket**](../proxy/logging_spec)
 
 #### Getting `service_account.json` from Google Cloud Console
 
@@ -914,6 +914,83 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
 
 Your logs should be available on the specified s3 Bucket
 
+## Azure Blob Storage
+
+Log LLM Logs to [Azure Data Lake Storage](https://learn.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-introduction)
+
+:::info
+
+✨ This is an Enterprise only feature [Get Started with Enterprise here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
+
+:::
+
+
+| Property | Details |
+|----------|---------|
+| Description | Log LLM Input/Output to Azure Blob Storage (Bucket) |
+| Azure Docs on Data Lake Storage | [Azure Data Lake Storage](https://learn.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-introduction) |
+
+
+
+#### Usage
+
+1. Add `azure_storage` to LiteLLM Config.yaml
+```yaml
+model_list:
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: openai/fake
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+
+litellm_settings:
+  callbacks: ["azure_storage"] # 👈 KEY CHANGE
+```
+
+2. Set required env variables
+
+```shell
+AZURE_STORAGE_ACCOUNT_NAME="litellm2" # The name of the Azure Storage Account to use for logging
+AZURE_STORAGE_FILE_SYSTEM="litellm-logs" # The name of the Azure Storage File System to use for logging. (Typically the Container name)
+AZURE_STORAGE_TENANT_ID="985efd7cxxxxxxxxxx" # The Application Tenant ID to use for Authentication
+AZURE_STORAGE_CLIENT_ID="abe66585xxxxxxxxxx" # The Application Client ID to use for Authentication
+AZURE_STORAGE_CLIENT_SECRET="uMS8Qxxxxxxxxxx" # The Application Client Secret to use for Authentication
+```
+
+3. Start Proxy
+
+```shell
+litellm --config /path/to/config.yaml
+```
+
+4. Test it!
+
+```bash
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Content-Type: application/json' \
+--data ' {
+    "model": "fake-openai-endpoint",
+    "messages": [
+        {
+            "role": "user",
+            "content": "what llm are you"
+        }
+    ]
+    }
+'
+```
+
+
+#### Expected Logs on Azure Data Lake Storage
+
+<Image img={require('../../img/azure_blob.png')} />
+
+#### Fields Logged on Azure Data Lake Storage
+
+[**The standard logging object is logged on Azure Data Lake Storage**](../proxy/logging_spec)
+
+
+
 ## DataDog
 
 LiteLLM Supports logging to the following Datdog Integrations:
diff --git a/docs/my-website/img/azure_blob.png b/docs/my-website/img/azure_blob.png
new file mode 100644
index 0000000000..750fe22577
Binary files /dev/null and b/docs/my-website/img/azure_blob.png differ
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 2cbc4efa05..70e2412a95 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -65,6 +65,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
     "arize",
     "langtrace",
     "gcs_bucket",
+    "azure_storage",
     "opik",
     "argilla",
     "mlflow",
diff --git a/litellm/constants.py b/litellm/constants.py
index 0c1b4a73d9..de745c63b8 100644
--- a/litellm/constants.py
+++ b/litellm/constants.py
@@ -71,6 +71,9 @@ LITELLM_CHAT_PROVIDERS = [
 
 RESPONSE_FORMAT_TOOL_NAME = "json_tool_call"  # default tool name used when converting response format to tool call
 
+########################### Logging Callback Constants ###########################
+AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
+
 ########################### LiteLLM Proxy Specific Constants ###########################
 MAX_SPENDLOG_ROWS_TO_QUERY = (
     1_000_000  # if spendLogs has more than 1M rows, do not query the DB
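For context on the implementation that follows: the new `AzureBlobStorageLogger` writes each payload to Azure Data Lake Storage Gen2 with a three-step REST flow — create the file, append the bytes, then flush to commit. A minimal standalone sketch of that flow (not part of the patch; `upload_log` and its arguments are illustrative, and a bearer token is assumed to have been acquired already):

```python
# Sketch of the ADLS Gen2 create/append/flush flow used by the logger below.
# Assumes `token` was already obtained for scope "https://storage.azure.com/.default".
import httpx

AZURE_STORAGE_MSFT_VERSION = "2019-07-07"  # mirrors litellm/constants.py above


def upload_log(account: str, file_system: str, filename: str, body: str, token: str) -> None:
    base_url = f"https://{account}.dfs.core.windows.net/{file_system}/{filename}"
    auth = {"x-ms-version": AZURE_STORAGE_MSFT_VERSION, "Authorization": f"Bearer {token}"}
    data = body.encode("utf-8")
    with httpx.Client() as client:
        # 1. create the file resource
        resp = client.put(f"{base_url}?resource=file", headers={**auth, "Content-Length": "0"})
        resp.raise_for_status()
        # 2. append the payload at byte position 0
        resp = client.patch(
            f"{base_url}?action=append&position=0",
            headers={**auth, "Content-Type": "application/json"},
            content=data,
        )
        resp.raise_for_status()
        # 3. flush at the final byte position to commit the appended data
        resp = client.patch(
            f"{base_url}?action=flush&position={len(data)}",
            headers={**auth, "Content-Length": "0"},
        )
        resp.raise_for_status()
```

Pinning `x-ms-version` (the `AZURE_STORAGE_MSFT_VERSION` constant added above) keeps the append/flush semantics stable across Azure Storage API revisions.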
diff --git a/litellm/integrations/azure_storage/azure_storage.py b/litellm/integrations/azure_storage/azure_storage.py
new file mode 100644
index 0000000000..0982268b44
--- /dev/null
+++ b/litellm/integrations/azure_storage/azure_storage.py
@@ -0,0 +1,333 @@
+import asyncio
+import json
+import os
+import uuid
+from datetime import datetime, timedelta
+from typing import List, Optional
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.constants import AZURE_STORAGE_MSFT_VERSION
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.azure.common_utils import get_azure_ad_token_from_entrata_id
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
+from litellm.types.utils import StandardLoggingPayload
+
+
+class AzureBlobStorageLogger(CustomBatchLogger):
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        try:
+            verbose_logger.debug(
+                "AzureBlobStorageLogger: in init azure blob storage logger"
+            )
+            # check if the correct env variables are set
+            _tenant_id = os.getenv("AZURE_STORAGE_TENANT_ID")
+            if not _tenant_id:
+                raise ValueError(
+                    "Missing required environment variable: AZURE_STORAGE_TENANT_ID"
+                )
+            self.tenant_id: str = _tenant_id
+
+            _client_id = os.getenv("AZURE_STORAGE_CLIENT_ID")
+            if not _client_id:
+                raise ValueError(
+                    "Missing required environment variable: AZURE_STORAGE_CLIENT_ID"
+                )
+            self.client_id: str = _client_id
+
+            _client_secret = os.getenv("AZURE_STORAGE_CLIENT_SECRET")
+            if not _client_secret:
+                raise ValueError(
+                    "Missing required environment variable: AZURE_STORAGE_CLIENT_SECRET"
+                )
+            self.client_secret: str = _client_secret
+
+            _azure_storage_account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
+            if not _azure_storage_account_name:
+                raise ValueError(
+                    "Missing required environment variable: AZURE_STORAGE_ACCOUNT_NAME"
+                )
+            self.azure_storage_account_name: str = _azure_storage_account_name
+
+            _azure_storage_file_system = os.getenv("AZURE_STORAGE_FILE_SYSTEM")
+            if not _azure_storage_file_system:
+                raise ValueError(
+                    "Missing required environment variable: AZURE_STORAGE_FILE_SYSTEM"
+                )
+            self.azure_storage_file_system: str = _azure_storage_file_system
+
+            self.azure_auth_token: Optional[str] = (
+                None  # the Azure AD token to use for Azure Storage API requests
+            )
+            self.token_expiry: Optional[datetime] = (
+                None  # the expiry time of the current Azure AD token
+            )
+
+            asyncio.create_task(self.periodic_flush())
+            self.flush_lock = asyncio.Lock()
+            self.log_queue: List[StandardLoggingPayload] = []
+            super().__init__(**kwargs, flush_lock=self.flush_lock)
+        except Exception as e:
+            verbose_logger.exception(
+                f"AzureBlobStorageLogger: Got exception on init AzureBlobStorageLogger client {str(e)}"
+            )
+            raise e
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        """
+        Async Log success events to Azure Blob Storage
+
+        Raises:
+            Raises a non-blocking verbose_logger.exception if an error occurs
+        """
+        try:
+            self._premium_user_check()
+            verbose_logger.debug(
+                "AzureBlobStorageLogger: Logging - Enters logging function for model %s",
+                kwargs,
+            )
+            standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object"
+            )
+
+            if standard_logging_payload is None:
+                raise ValueError("standard_logging_payload is not set")
+
+            self.log_queue.append(standard_logging_payload)
+
+        except Exception as e:
+            verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}")
+            pass
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        """
+        Async Log failure events to Azure Blob Storage
+
+        Raises:
+            Raises a non-blocking verbose_logger.exception if an error occurs
+        """
+        try:
+            self._premium_user_check()
+            verbose_logger.debug(
+                "AzureBlobStorageLogger: Logging - Enters logging function for model %s",
+                kwargs,
+            )
+            standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object"
+            )
+
+            if standard_logging_payload is None:
+                raise ValueError("standard_logging_payload is not set")
+
+            self.log_queue.append(standard_logging_payload)
+        except Exception as e:
+            verbose_logger.exception(f"AzureBlobStorageLogger Layer Error - {str(e)}")
+            pass
+
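+    # Both handlers above only enqueue the standard logging payload; the
+    # CustomBatchLogger base class drains the queue on a timer (the
+    # periodic_flush() task started in __init__) by calling async_send_batch().
+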
+    async def async_send_batch(self):
+        """
+        Sends the in-memory log queue to Azure Blob Storage
+
+        Raises:
+            Raises a non-blocking verbose_logger.exception if an error occurs
+        """
+        try:
+            if not self.log_queue:
+                verbose_logger.exception(
+                    "AzureBlobStorageLogger: log_queue does not exist"
+                )
+                return
+
+            verbose_logger.debug(
+                "AzureBlobStorageLogger - about to flush %s events",
+                len(self.log_queue),
+            )
+
+            # Get a valid token instead of always requesting a new one
+            await self.set_valid_azure_ad_token()
+
+            for payload in self.log_queue:
+                await self.async_upload_payload_to_azure_blob_storage(payload=payload)
+
+        except Exception as e:
+            verbose_logger.exception(
+                f"AzureBlobStorageLogger Error sending batch API - {str(e)}"
+            )
+
+    async def async_upload_payload_to_azure_blob_storage(
+        self, payload: StandardLoggingPayload
+    ):
+        """
+        Uploads the payload to Azure Blob Storage using a 3-step process:
+        1. Create file resource
+        2. Append data
+        3. Flush the data
+        """
+        try:
+            async_client = get_async_httpx_client(
+                llm_provider=httpxSpecialProvider.LoggingCallback
+            )
+            json_payload = json.dumps(payload) + "\n"  # Add newline for each log entry
+            payload_bytes = json_payload.encode("utf-8")
+            filename = f"{payload.get('id') or str(uuid.uuid4())}.json"
+            base_url = f"https://{self.azure_storage_account_name}.dfs.core.windows.net/{self.azure_storage_file_system}/{filename}"
+
+            # Execute the 3-step upload process
+            await self._create_file(async_client, base_url)
+            await self._append_data(async_client, base_url, json_payload)
+            await self._flush_data(async_client, base_url, len(payload_bytes))
+
+            verbose_logger.debug(
+                f"Successfully uploaded log to Azure Blob Storage: {filename}"
+            )
+
+        except Exception as e:
+            verbose_logger.exception(f"Error uploading to Azure Blob Storage: {str(e)}")
+            raise e
+
+    async def _create_file(self, client: AsyncHTTPHandler, base_url: str):
+        """Helper method to create the file resource"""
+        try:
+            verbose_logger.debug(f"Creating file resource at: {base_url}")
+            headers = {
+                "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
+                "Content-Length": "0",
+                "Authorization": f"Bearer {self.azure_auth_token}",
+            }
+            response = await client.put(f"{base_url}?resource=file", headers=headers)
+            response.raise_for_status()
+            verbose_logger.debug("Successfully created file resource")
+        except Exception as e:
+            verbose_logger.exception(f"Error creating file resource: {str(e)}")
+            raise
+
+    async def _append_data(
+        self, client: AsyncHTTPHandler, base_url: str, json_payload: str
+    ):
+        """Helper method to append data to the file"""
+        try:
+            verbose_logger.debug(f"Appending data to file: {base_url}")
+            headers = {
+                "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.azure_auth_token}",
+            }
+            response = await client.patch(
+                f"{base_url}?action=append&position=0",
+                headers=headers,
+                data=json_payload,
+            )
+            response.raise_for_status()
+            verbose_logger.debug("Successfully appended data")
+        except Exception as e:
+            verbose_logger.exception(f"Error appending data: {str(e)}")
+            raise
+
+    async def _flush_data(self, client: AsyncHTTPHandler, base_url: str, position: int):
+        """Helper method to flush the data"""
+        try:
+            verbose_logger.debug(f"Flushing data at position {position}")
+            headers = {
+                "x-ms-version": AZURE_STORAGE_MSFT_VERSION,
+                "Content-Length": "0",
+                "Authorization": f"Bearer {self.azure_auth_token}",
+            }
+            response = await client.patch(
+                f"{base_url}?action=flush&position={position}", headers=headers
+            )
+            response.raise_for_status()
+            verbose_logger.debug("Successfully flushed data")
+        except Exception as e:
+            verbose_logger.exception(f"Error flushing data: {str(e)}")
+            raise
+
+    ####### Helper methods for managing authentication to Azure Storage #######
+    ###########################################################################
+
+    async def set_valid_azure_ad_token(self):
+        """
+        Wrapper to set self.azure_auth_token to a valid Azure AD token, refreshing if necessary
+
+        Refreshes the token when:
+        - Token is expired
+        - Token is not set
+        """
+        # Check if token needs refresh
+        if self._azure_ad_token_is_expired() or self.azure_auth_token is None:
+            verbose_logger.debug("Azure AD token needs refresh")
+            self.azure_auth_token = self.get_azure_ad_token_from_azure_storage(
+                tenant_id=self.tenant_id,
+                client_id=self.client_id,
+                client_secret=self.client_secret,
+            )
+            # Token typically expires in 1 hour
+            self.token_expiry = datetime.now() + timedelta(hours=1)
+            verbose_logger.debug(f"New token will expire at {self.token_expiry}")
+
+    def get_azure_ad_token_from_azure_storage(
+        self,
+        tenant_id: str,
+        client_id: str,
+        client_secret: str,
+    ) -> str:
+        """
+        Gets Azure AD token to use for Azure Storage API requests
+        """
+        verbose_logger.debug("Getting Azure AD Token from Azure Storage")
+        verbose_logger.debug(
+            "tenant_id %s, client_id %s",  # do not log the client secret
+            tenant_id,
+            client_id,
+        )
+        token_provider = get_azure_ad_token_from_entrata_id(
+            tenant_id=tenant_id,
+            client_id=client_id,
+            client_secret=client_secret,
+            scope="https://storage.azure.com/.default",
+        )
+        token = token_provider()
+
+        verbose_logger.debug("successfully fetched Azure AD token")
+
+        return token
+
+    def _azure_ad_token_is_expired(self):
+        """
+        Returns True if Azure AD token is expired, False otherwise
+        """
+        if self.azure_auth_token and self.token_expiry:
+            if datetime.now() + timedelta(minutes=5) >= self.token_expiry:
+                verbose_logger.debug("Azure AD token is expired. Requesting new token")
+                return True
+        return False
+
+    def _premium_user_check(self):
+        """
+        Checks if the user is a premium user, raises an error if not
+        """
+        from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
+
+        if premium_user is not True:
+            raise ValueError(
+                f"AzureBlobStorageLogger is only available for premium users. 
{CommonProxyErrors.not_premium_user}" + ) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 23ebb6ccd5..f14307219e 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -67,6 +67,7 @@ from litellm.utils import ( from ..integrations.argilla import ArgillaLogger from ..integrations.arize_ai import ArizeLogger from ..integrations.athina import AthinaLogger +from ..integrations.azure_storage.azure_storage import AzureBlobStorageLogger from ..integrations.braintrust_logging import BraintrustLogger from ..integrations.datadog.datadog import DataDogLogger from ..integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger @@ -2226,6 +2227,14 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _gcs_bucket_logger = GCSBucketLogger() _in_memory_loggers.append(_gcs_bucket_logger) return _gcs_bucket_logger # type: ignore + elif logging_integration == "azure_storage": + for callback in _in_memory_loggers: + if isinstance(callback, AzureBlobStorageLogger): + return callback # type: ignore + + _azure_storage_logger = AzureBlobStorageLogger() + _in_memory_loggers.append(_azure_storage_logger) + return _azure_storage_logger # type: ignore elif logging_integration == "opik": for callback in _in_memory_loggers: if isinstance(callback, OpikLogger): @@ -2410,6 +2419,10 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 for callback in _in_memory_loggers: if isinstance(callback, GCSBucketLogger): return callback + elif logging_integration == "azure_storage": + for callback in _in_memory_loggers: + if isinstance(callback, AzureBlobStorageLogger): + return callback elif logging_integration == "opik": for callback in _in_memory_loggers: if isinstance(callback, OpikLogger): diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index b5033295c4..f54a5499c0 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -1,8 +1,10 @@ -from typing import Optional, Union +from typing import Callable, Optional, Union import httpx +from litellm._logging import verbose_logger from litellm.llms.base_llm.transformation import BaseLLMException +from litellm.secret_managers.main import get_secret_str class AzureOpenAIError(BaseLLMException): @@ -44,3 +46,63 @@ def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict: } return {**llm_response_headers, **openai_headers} + + +def get_azure_ad_token_from_entrata_id( + tenant_id: str, + client_id: str, + client_secret: str, + scope: str = "https://cognitiveservices.azure.com/.default", +) -> Callable[[], str]: + """ + Get Azure AD token provider from `client_id`, `client_secret`, and `tenant_id` + + Args: + tenant_id: str + client_id: str + client_secret: str + scope: str + + Returns: + callable that returns a bearer token. 
+    """
+    from azure.identity import (
+        ClientSecretCredential,
+        get_bearer_token_provider,
+    )
+
+    verbose_logger.debug("Getting Azure AD Token from Entrata ID")
+
+    if tenant_id.startswith("os.environ/"):
+        _tenant_id = get_secret_str(tenant_id)
+    else:
+        _tenant_id = tenant_id
+
+    if client_id.startswith("os.environ/"):
+        _client_id = get_secret_str(client_id)
+    else:
+        _client_id = client_id
+
+    if client_secret.startswith("os.environ/"):
+        _client_secret = get_secret_str(client_secret)
+    else:
+        _client_secret = client_secret
+
+    verbose_logger.debug(
+        "tenant_id %s, client_id %s",  # do not log the client secret
+        _tenant_id,
+        _client_id,
+    )
+    if _tenant_id is None or _client_id is None or _client_secret is None:
+        raise ValueError("tenant_id, client_id, and client_secret must be provided")
+    credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)
+
+    verbose_logger.debug("credential %s", credential)
+
+    token_provider = get_bearer_token_provider(credential, scope)
+
+    verbose_logger.debug("token_provider %s", token_provider)
+
+    return token_provider
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index d08bc794fc..aa91662918 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -284,6 +284,66 @@ class AsyncHTTPHandler:
         except Exception as e:
             raise e
 
+    async def patch(
+        self,
+        url: str,
+        data: Optional[Union[dict, str]] = None,  # type: ignore
+        json: Optional[dict] = None,
+        params: Optional[dict] = None,
+        headers: Optional[dict] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        stream: bool = False,
+    ):
+        try:
+            if timeout is None:
+                timeout = self.timeout
+
+            req = self.client.build_request(
+                "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore
+            )
+            response = await self.client.send(req)
+            response.raise_for_status()
+            return response
+        except (httpx.RemoteProtocolError, httpx.ConnectError):
+            # Retry the request with a new session if there is a connection error.
+            # Rebuild the request as a PATCH so the verb is preserved on retry.
+            new_client = self.create_client(
+                timeout=timeout, concurrent_limit=1, event_hooks=self.event_hooks
+            )
+            try:
+                req = new_client.build_request(
+                    "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore
+                )
+                response = await new_client.send(req, stream=stream)
+                response.raise_for_status()
+                return response
+            finally:
+                await new_client.aclose()
+        except httpx.TimeoutException as e:
+            headers = {}
+            error_response = getattr(e, "response", None)
+            if error_response is not None:
+                for key, value in error_response.headers.items():
+                    headers["response_headers-{}".format(key)] = value
+
+            raise litellm.Timeout(
+                message=f"Connection timed out after {timeout} seconds.",
+                model="default-model-name",
+                llm_provider="litellm-httpx-handler",
+                headers=headers,
+            )
+        except httpx.HTTPStatusError as e:
+            setattr(e, "status_code", e.response.status_code)
+            if stream is True:
+                setattr(e, "message", await e.response.aread())
+            else:
+                setattr(e, "message", e.response.text)
+            raise e
+        except Exception as e:
+            raise e
+
     async def delete(
         self,
         url: str,
@@ -472,6 +532,51 @@ class HTTPHandler:
         except Exception as e:
             raise e
 
+    def patch(
+        self,
+        url: str,
+        data: Optional[Union[dict, str]] = None,
+        json: Optional[Union[dict, str]] = None,
+        params: Optional[dict] = None,
+        headers: Optional[dict] = None,
+        stream: bool = False,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+    ):
+        try:
+            if timeout is not None:
+                req = 
self.client.build_request( + "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + else: + req = self.client.build_request( + "PATCH", url, data=data, json=json, params=params, headers=headers # type: ignore + ) + response = self.client.send(req, stream=stream) + response.raise_for_status() + return response + except httpx.TimeoutException: + raise litellm.Timeout( + message=f"Connection timed out after {timeout} seconds.", + model="default-model-name", + llm_provider="litellm-httpx-handler", + ) + except httpx.HTTPStatusError as e: + + if stream is True: + setattr(e, "message", mask_sensitive_info(e.response.read())) + setattr(e, "text", mask_sensitive_info(e.response.read())) + else: + error_text = mask_sensitive_info(e.response.text) + setattr(e, "message", error_text) + setattr(e, "text", error_text) + + setattr(e, "status_code", e.response.status_code) + + raise e + except Exception as e: + raise e + def put( self, url: str, diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py index 8a11edce8f..70b8c71fdc 100644 --- a/litellm/router_utils/client_initalization_utils.py +++ b/litellm/router_utils/client_initalization_utils.py @@ -10,6 +10,7 @@ import litellm from litellm import get_secret, get_secret_str from litellm._logging import verbose_router_logger from litellm.llms.azure.azure import get_azure_ad_token_from_oidc +from litellm.llms.azure.common_utils import get_azure_ad_token_from_entrata_id from litellm.secret_managers.get_azure_ad_token_provider import ( get_azure_ad_token_provider, ) @@ -196,12 +197,10 @@ class InitalizeOpenAISDKClient: verbose_router_logger.debug( "Using Azure AD Token Provider for Azure Auth" ) - azure_ad_token_provider = ( - InitalizeOpenAISDKClient.get_azure_ad_token_from_entrata_id( - tenant_id=litellm_params.get("tenant_id"), - client_id=litellm_params.get("client_id"), - client_secret=litellm_params.get("client_secret"), - ) + azure_ad_token_provider = get_azure_ad_token_from_entrata_id( + tenant_id=litellm_params.get("tenant_id"), + client_id=litellm_params.get("client_id"), + client_secret=litellm_params.get("client_secret"), ) if custom_llm_provider == "azure" or custom_llm_provider == "azure_text": @@ -550,50 +549,3 @@ class InitalizeOpenAISDKClient: ttl=client_ttl, local_only=True, ) # cache for 1 hr - - @staticmethod - def get_azure_ad_token_from_entrata_id( - tenant_id: str, client_id: str, client_secret: str - ) -> Callable[[], str]: - from azure.identity import ( - ClientSecretCredential, - DefaultAzureCredential, - get_bearer_token_provider, - ) - - verbose_router_logger.debug("Getting Azure AD Token from Entrata ID") - - if tenant_id.startswith("os.environ/"): - _tenant_id = get_secret_str(tenant_id) - else: - _tenant_id = tenant_id - - if client_id.startswith("os.environ/"): - _client_id = get_secret_str(client_id) - else: - _client_id = client_id - - if client_secret.startswith("os.environ/"): - _client_secret = get_secret_str(client_secret) - else: - _client_secret = client_secret - - verbose_router_logger.debug( - "tenant_id %s, client_id %s, client_secret %s", - _tenant_id, - _client_id, - _client_secret, - ) - if _tenant_id is None or _client_id is None or _client_secret is None: - raise ValueError("tenant_id, client_id, and client_secret must be provided") - credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret) - - verbose_router_logger.debug("credential %s", credential) - - token_provider = 
get_bearer_token_provider(
-            credential, "https://cognitiveservices.azure.com/.default"
-        )
-
-        verbose_logger.debug("token_provider %s", token_provider)
-
-        return token_provider
diff --git a/tests/logging_callback_tests/test_azure_blob_storage.py b/tests/logging_callback_tests/test_azure_blob_storage.py
new file mode 100644
index 0000000000..a90f253cc9
--- /dev/null
+++ b/tests/logging_callback_tests/test_azure_blob_storage.py
@@ -0,0 +1,45 @@
+import os
+import sys
+
+sys.path.insert(0, os.path.abspath("../.."))
+
+import asyncio
+import logging
+
+import pytest
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.integrations.azure_storage.azure_storage import AzureBlobStorageLogger
+
+verbose_logger.setLevel(logging.DEBUG)
+
+
+@pytest.mark.asyncio
+async def test_azure_blob_storage():
+    azure_storage_logger = AzureBlobStorageLogger(flush_interval=1)
+    litellm.callbacks = [azure_storage_logger]
+
+    response = await litellm.acompletion(
+        model="gpt-4o",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+    )
+    print(response)
+
+    # give the batch logger time to flush to Azure Blob Storage
+    await asyncio.sleep(3)
diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
index 7150c7bf8b..8e994631a5 100644
--- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
+++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
@@ -36,6 +36,7 @@ from litellm.integrations.argilla import ArgillaLogger
 from litellm.integrations.langfuse.langfuse_prompt_management import (
     LangfusePromptManagement,
 )
+from litellm.integrations.azure_storage.azure_storage import AzureBlobStorageLogger
 from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
 from unittest.mock import patch
 
@@ -59,6 +60,7 @@ callback_class_str_to_classType = {
     "opik": OpikLogger,
     "argilla": ArgillaLogger,
     "opentelemetry": OpenTelemetry,
+    "azure_storage": AzureBlobStorageLogger,
     # OTEL compatible loggers
     "logfire": OpenTelemetry,
     "arize": OpenTelemetry,
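For reviewers who want to exercise the callback outside the proxy, a minimal SDK-level sketch mirroring the test above (assumes the `AZURE_STORAGE_*` env vars from the docs section are set, a premium-user license for `_premium_user_check`, and an `OPENAI_API_KEY` for the `gpt-4o` call):

```python
import asyncio

import litellm


async def main():
    # the string form resolves to AzureBlobStorageLogger via
    # _init_custom_logger_compatible_class in litellm_logging.py
    litellm.callbacks = ["azure_storage"]

    response = await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello, world!"}],
    )
    print(response)

    # the batch logger flushes on an interval; wait before the loop exits
    await asyncio.sleep(3)


asyncio.run(main())
```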