feat(exceptions): implement sensitive data masking in exception messages

Cole McIntosh 2025-03-15 17:04:09 -06:00
parent f505716499
commit c7b2596965
2 changed files with 140 additions and 26 deletions
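
In short: every litellm exception class now routes its message through a
module-level _mask_message() helper, so secrets embedded in provider errors
are redacted before they reach logs or callers. A minimal sketch of the
intended behavior, mirroring the new test added below (the provider and model
names here are placeholders):

    from litellm import exceptions

    exc = exceptions.AuthenticationError(
        message="Failed to authenticate with API key sk-12345678901234567890",
        llm_provider="openai",  # placeholder provider
        model="gpt-4o",         # placeholder model
    )
    # The raw key no longer appears in exc.message; a short prefix and
    # suffix survive (exact shape depends on SensitiveDataMasker._mask_value).
    assert "sk-12345678901234567890" not in exc.message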


@@ -15,7 +15,62 @@ import httpx
import openai
from litellm.types.utils import LiteLLMCommonStrings
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
# Initialize a single SensitiveDataMasker instance to be used across all exception classes
_sensitive_data_masker = SensitiveDataMasker()
def _mask_message(message):
"""Helper function to mask sensitive data in exception messages"""
if not message:
return message
# Directly process the message string to mask sensitive patterns
import re
    # Common API key and secret prefixes (sk-, AKIA, key-, token-, etc.)
patterns = [
# OpenAI and similar keys
r'sk-[a-zA-Z0-9]{10,}',
r'sk_[a-zA-Z0-9]{10,}',
# AWS keys
r'AKIA[0-9A-Z]{16}',
r'[a-zA-Z0-9+/]{40}', # AWS secret key pattern
# Azure keys
r'[a-zA-Z0-9]{32}',
# Database connection strings
r'(mongodb(\+srv)?:\/\/)[^:]+:[^@]+@[^\/]+',
# API tokens and keys
r'key-[a-zA-Z0-9]{24,}',
r'token-[a-zA-Z0-9]{24,}',
# Named keys and secrets
r'secret_[a-zA-Z0-9]{5,}',
r'pass[a-zA-Z0-9]{3,}word',
# Generic patterns with capture groups
r'(API key[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
r'(api[_-]?key[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
r'(secret[_-]?key[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
r'(access[_-]?key[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
r'(password[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
r'(token[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
]
# Apply masking
masked_message = message
for pattern in patterns:
if '(' in pattern and ')' in pattern: # Has capturing groups
            # For patterns with capturing groups, keep the prefix and mask the value.
            # An optional group can match nothing (e.g. `(\+srv)?` in the MongoDB
            # pattern), so fall back to masking the entire match in that case.
            def replace_func(match):
                if len(match.groups()) > 1 and match.group(2) is not None:
                    return match.group(1) + _sensitive_data_masker._mask_value(match.group(2))
                return _sensitive_data_masker._mask_value(match.group(0))
masked_message = re.sub(pattern, replace_func, masked_message)
else:
# For patterns without capturing groups, mask the entire match
masked_message = re.sub(pattern, lambda m: _sensitive_data_masker._mask_value(m.group(0)), masked_message)
return masked_message
class AuthenticationError(openai.AuthenticationError): # type: ignore
def __init__(
@@ -29,7 +84,7 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
num_retries: Optional[int] = None,
):
self.status_code = 401
self.message = "litellm.AuthenticationError: {}".format(message)
self.message = _mask_message("litellm.AuthenticationError: {}".format(message))
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
@@ -75,7 +130,7 @@ class NotFoundError(openai.NotFoundError): # type: ignore
num_retries: Optional[int] = None,
):
self.status_code = 404
self.message = "litellm.NotFoundError: {}".format(message)
self.message = _mask_message("litellm.NotFoundError: {}".format(message))
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
@@ -121,7 +176,7 @@ class BadRequestError(openai.BadRequestError): # type: ignore
body: Optional[dict] = None,
):
self.status_code = 400
self.message = "litellm.BadRequestError: {}".format(message)
self.message = _mask_message("litellm.BadRequestError: {}".format(message))
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
@@ -166,7 +221,7 @@ class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
num_retries: Optional[int] = None,
):
self.status_code = 422
self.message = "litellm.UnprocessableEntityError: {}".format(message)
self.message = _mask_message("litellm.UnprocessableEntityError: {}".format(message))
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
@@ -212,7 +267,7 @@ class Timeout(openai.APITimeoutError): # type: ignore
request=request
) # Call the base class constructor with the parameters it needs
self.status_code = 408
self.message = "litellm.Timeout: {}".format(message)
self.message = _mask_message("litellm.Timeout: {}".format(message))
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
@@ -250,7 +305,7 @@ class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
num_retries: Optional[int] = None,
):
self.status_code = 403
self.message = "litellm.PermissionDeniedError: {}".format(message)
self.message = _mask_message("litellm.PermissionDeniedError: {}".format(message))
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
@@ -289,7 +344,7 @@ class RateLimitError(openai.RateLimitError): # type: ignore
num_retries: Optional[int] = None,
):
self.status_code = 429
self.message = "litellm.RateLimitError: {}".format(message)
self.message = _mask_message("litellm.RateLimitError: {}".format(message))
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
@@ -354,7 +409,7 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
) # Call the base class constructor with the parameters it needs
# set after, to make it clear the raised error is a context window exceeded error
self.message = "litellm.ContextWindowExceededError: {}".format(self.message)
self.message = _mask_message("litellm.ContextWindowExceededError: {}".format(self.message))
def __str__(self):
_message = self.message
@@ -384,7 +439,7 @@ class RejectedRequestError(BadRequestError): # type: ignore
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400
self.message = "litellm.RejectedRequestError: {}".format(message)
self.message = _mask_message("litellm.RejectedRequestError: {}".format(message))
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
@@ -427,7 +482,7 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400
self.message = "litellm.ContentPolicyViolationError: {}".format(message)
self.message = _mask_message("litellm.ContentPolicyViolationError: {}".format(message))
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
@@ -470,7 +525,7 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
num_retries: Optional[int] = None,
):
self.status_code = 503
self.message = "litellm.ServiceUnavailableError: {}".format(message)
self.message = _mask_message("litellm.ServiceUnavailableError: {}".format(message))
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
@@ -516,7 +571,7 @@ class InternalServerError(openai.InternalServerError): # type: ignore
num_retries: Optional[int] = None,
):
self.status_code = 500
self.message = "litellm.InternalServerError: {}".format(message)
self.message = _mask_message("litellm.InternalServerError: {}".format(message))
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
@@ -564,7 +619,7 @@ class APIError(openai.APIError): # type: ignore
num_retries: Optional[int] = None,
):
self.status_code = status_code
self.message = "litellm.APIError: {}".format(message)
self.message = _mask_message("litellm.APIError: {}".format(message))
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
@@ -603,7 +658,7 @@ class APIConnectionError(openai.APIConnectionError): # type: ignore
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.message = "litellm.APIConnectionError: {}".format(message)
self.message = _mask_message("litellm.APIConnectionError: {}".format(message))
self.llm_provider = llm_provider
self.model = model
self.status_code = 500
@@ -641,7 +696,7 @@ class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.message = "litellm.APIResponseValidationError: {}".format(message)
self.message = _mask_message("litellm.APIResponseValidationError: {}".format(message))
self.llm_provider = llm_provider
self.model = model
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
@@ -675,9 +730,9 @@ class JSONSchemaValidationError(APIResponseValidationError):
self.raw_response = raw_response
self.schema = schema
self.model = model
message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
message = _mask_message("litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
model, raw_response, schema
)
))
self.message = message
super().__init__(model=model, message=message, llm_provider=llm_provider)
@@ -701,7 +756,7 @@ class UnsupportedParamsError(BadRequestError):
num_retries: Optional[int] = None,
):
self.status_code = 400
self.message = "litellm.UnsupportedParamsError: {}".format(message)
self.message = _mask_message("litellm.UnsupportedParamsError: {}".format(message))
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
@@ -746,7 +801,7 @@ class BudgetExceededError(Exception):
self.max_budget = max_budget
message = (
message
or f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
or _mask_message(f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}")
)
self.message = message
super().__init__(message)
@@ -756,17 +811,17 @@ class BudgetExceededError(Exception):
class InvalidRequestError(openai.BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider):
self.status_code = 400
self.message = message
self.message = _mask_message("litellm.InvalidRequestError: {}".format(message))
self.model = model
self.llm_provider = llm_provider
self.response = httpx.Response(
status_code=400,
response = httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
super().__init__(
message=self.message, response=self.response, body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
@@ -784,7 +839,7 @@ class MockException(openai.APIError):
num_retries: Optional[int] = None,
):
self.status_code = status_code
self.message = "litellm.MockException: {}".format(message)
self.message = _mask_message("litellm.MockException: {}".format(message))
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
@@ -797,9 +852,9 @@ class MockException(openai.APIError):
class LiteLLMUnknownProvider(BadRequestError):
def __init__(self, model: str, custom_llm_provider: Optional[str] = None):
self.message = LiteLLMCommonStrings.llm_provider_not_provided.value.format(
self.message = _mask_message(LiteLLMCommonStrings.llm_provider_not_provided.value.format(
model=model, custom_llm_provider=custom_llm_provider
)
))
super().__init__(
self.message, model=model, llm_provider=custom_llm_provider, response=None
)
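
A note on the capture-group branch in _mask_message(): for patterns such as
(api[_-]?key[:=]?\s*)..., group 1 is the human-readable label and group 2 the
secret, so only the secret is replaced. A self-contained sketch of that idea,
using a stand-in for SensitiveDataMasker._mask_value (whose exact output is
not shown in this diff):

    import re

    def demo_mask(value: str) -> str:
        # Stand-in masker: keep four characters on each end, as the new
        # tests appear to expect.
        return value[:4] + "****" + value[-4:] if len(value) > 8 else "****"

    pattern = r'(api[_-]?key[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?'
    text = "request failed: api_key=abcd1234efgh5678"
    print(re.sub(pattern, lambda m: m.group(1) + demo_mask(m.group(2)), text))
    # -> request failed: api_key=abcd****5678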


@@ -0,0 +1,59 @@
import unittest
import litellm
from litellm import exceptions
class TestExceptionsMasking(unittest.TestCase):
def test_api_key_masking_in_exceptions(self):
"""Test that API keys are properly masked in exception messages"""
# Test with a message containing an API key
api_key = "sk-12345678901234567890"
message = f"Failed to authenticate with API key {api_key}"
# Create an exception with this message
exception = exceptions.AuthenticationError(
message=message,
llm_provider="test_provider",
model="test_model"
)
# Check that the API key is not present in the exception message
self.assertNotIn(api_key, exception.message)
# Check that a masked version is present instead (should have the prefix and suffix)
self.assertIn("sk-1", exception.message)
self.assertIn("7890", exception.message)
def test_multiple_sensitive_keys_masked(self):
"""Test that multiple sensitive keys in the same message are masked"""
        # Message containing multiple pieces of sensitive information
message = (
"Error occurred. API key: sk-abc123def456, "
"Secret key: secret_xyz987, "
"Password: pass123word"
)
# Create an exception with this message
exception = exceptions.BadRequestError(
message=message,
model="test_model",
llm_provider="test_provider"
)
# Check that none of the sensitive data is present
self.assertNotIn("sk-abc123def456", exception.message)
self.assertNotIn("secret_xyz987", exception.message)
self.assertNotIn("pass123word", exception.message)
# Check that masked versions are present
self.assertIn("sk-a", exception.message)
self.assertIn("456", exception.message)
self.assertIn("secr", exception.message)
self.assertIn("987", exception.message)
self.assertIn("pass", exception.message)
self.assertIn("word", exception.message)
if __name__ == "__main__":
unittest.main()
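
Assuming the new module is saved somewhere unittest discovery can find it
(the file path is not shown in this view; tests/test_exceptions_masking.py is
a hypothetical name), the suite can be run directly:

    python -m unittest tests/test_exceptions_masking.py -v

Note that the assertions only check that a short prefix and suffix of each
secret survive, so they remain valid even if SensitiveDataMasker changes its
exact mask string.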