mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
feat(exceptions): implement sensitive data masking in exception messages
This commit is contained in:
parent
f505716499
commit
c7b2596965
2 changed files with 140 additions and 26 deletions
|
@ -15,7 +15,62 @@ import httpx
|
|||
import openai
|
||||
|
||||
from litellm.types.utils import LiteLLMCommonStrings
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
|
||||
# Initialize a single SensitiveDataMasker instance to be used across all exception classes
|
||||
_sensitive_data_masker = SensitiveDataMasker()
|
||||
|
||||
def _mask_message(message):
|
||||
"""Helper function to mask sensitive data in exception messages"""
|
||||
if not message:
|
||||
return message
|
||||
|
||||
# Directly process the message string to mask sensitive patterns
|
||||
import re
|
||||
|
||||
# Common API key patterns (sk-, pk-, api-, etc.)
|
||||
patterns = [
|
||||
# OpenAI and similar keys
|
||||
r'sk-[a-zA-Z0-9]{10,}',
|
||||
r'sk_[a-zA-Z0-9]{10,}',
|
||||
# AWS keys
|
||||
r'AKIA[0-9A-Z]{16}',
|
||||
r'[a-zA-Z0-9+/]{40}', # AWS secret key pattern
|
||||
# Azure keys
|
||||
r'[a-zA-Z0-9]{32}',
|
||||
# Database connection strings
|
||||
r'(mongodb(\+srv)?:\/\/)[^:]+:[^@]+@[^\/]+',
|
||||
# API tokens and keys
|
||||
r'key-[a-zA-Z0-9]{24,}',
|
||||
r'token-[a-zA-Z0-9]{24,}',
|
||||
# Named keys and secrets
|
||||
r'secret_[a-zA-Z0-9]{5,}',
|
||||
r'pass[a-zA-Z0-9]{3,}word',
|
||||
# Generic patterns with capture groups
|
||||
r'(API key[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
|
||||
r'(api[_-]?key[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
|
||||
r'(secret[_-]?key[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
|
||||
r'(access[_-]?key[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
|
||||
r'(password[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
|
||||
r'(token[:=]?\s*)[\'"]?([a-zA-Z0-9_\-\.]{6,})[\'"]?',
|
||||
]
|
||||
|
||||
# Apply masking
|
||||
masked_message = message
|
||||
for pattern in patterns:
|
||||
if '(' in pattern and ')' in pattern: # Has capturing groups
|
||||
# For patterns with capturing groups, keep the prefix and mask the value
|
||||
def replace_func(match):
|
||||
if len(match.groups()) > 1:
|
||||
return match.group(1) + _sensitive_data_masker._mask_value(match.group(2))
|
||||
return _sensitive_data_masker._mask_value(match.group(0))
|
||||
|
||||
masked_message = re.sub(pattern, replace_func, masked_message)
|
||||
else:
|
||||
# For patterns without capturing groups, mask the entire match
|
||||
masked_message = re.sub(pattern, lambda m: _sensitive_data_masker._mask_value(m.group(0)), masked_message)
|
||||
|
||||
return masked_message
|
||||
|
||||
class AuthenticationError(openai.AuthenticationError): # type: ignore
|
||||
def __init__(
|
||||
|
@ -29,7 +84,7 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
|
|||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 401
|
||||
self.message = "litellm.AuthenticationError: {}".format(message)
|
||||
self.message = _mask_message("litellm.AuthenticationError: {}".format(message))
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -75,7 +130,7 @@ class NotFoundError(openai.NotFoundError): # type: ignore
|
|||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 404
|
||||
self.message = "litellm.NotFoundError: {}".format(message)
|
||||
self.message = _mask_message("litellm.NotFoundError: {}".format(message))
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -121,7 +176,7 @@ class BadRequestError(openai.BadRequestError): # type: ignore
|
|||
body: Optional[dict] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = "litellm.BadRequestError: {}".format(message)
|
||||
self.message = _mask_message("litellm.BadRequestError: {}".format(message))
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -166,7 +221,7 @@ class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
|
|||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 422
|
||||
self.message = "litellm.UnprocessableEntityError: {}".format(message)
|
||||
self.message = _mask_message("litellm.UnprocessableEntityError: {}".format(message))
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -212,7 +267,7 @@ class Timeout(openai.APITimeoutError): # type: ignore
|
|||
request=request
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
self.status_code = 408
|
||||
self.message = "litellm.Timeout: {}".format(message)
|
||||
self.message = _mask_message("litellm.Timeout: {}".format(message))
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -250,7 +305,7 @@ class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
|
|||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 403
|
||||
self.message = "litellm.PermissionDeniedError: {}".format(message)
|
||||
self.message = _mask_message("litellm.PermissionDeniedError: {}".format(message))
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -289,7 +344,7 @@ class RateLimitError(openai.RateLimitError): # type: ignore
|
|||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 429
|
||||
self.message = "litellm.RateLimitError: {}".format(message)
|
||||
self.message = _mask_message("litellm.RateLimitError: {}".format(message))
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -354,7 +409,7 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
# set after, to make it clear the raised error is a context window exceeded error
|
||||
self.message = "litellm.ContextWindowExceededError: {}".format(self.message)
|
||||
self.message = _mask_message("litellm.ContextWindowExceededError: {}".format(self.message))
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
|
@ -384,7 +439,7 @@ class RejectedRequestError(BadRequestError): # type: ignore
|
|||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = "litellm.RejectedRequestError: {}".format(message)
|
||||
self.message = _mask_message("litellm.RejectedRequestError: {}".format(message))
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -427,7 +482,7 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
|
|||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = "litellm.ContentPolicyViolationError: {}".format(message)
|
||||
self.message = _mask_message("litellm.ContentPolicyViolationError: {}".format(message))
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -470,7 +525,7 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
|
|||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 503
|
||||
self.message = "litellm.ServiceUnavailableError: {}".format(message)
|
||||
self.message = _mask_message("litellm.ServiceUnavailableError: {}".format(message))
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -516,7 +571,7 @@ class InternalServerError(openai.InternalServerError): # type: ignore
|
|||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 500
|
||||
self.message = "litellm.InternalServerError: {}".format(message)
|
||||
self.message = _mask_message("litellm.InternalServerError: {}".format(message))
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -564,7 +619,7 @@ class APIError(openai.APIError): # type: ignore
|
|||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = "litellm.APIError: {}".format(message)
|
||||
self.message = _mask_message("litellm.APIError: {}".format(message))
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -603,7 +658,7 @@ class APIConnectionError(openai.APIConnectionError): # type: ignore
|
|||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.message = "litellm.APIConnectionError: {}".format(message)
|
||||
self.message = _mask_message("litellm.APIConnectionError: {}".format(message))
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.status_code = 500
|
||||
|
@ -641,7 +696,7 @@ class APIResponseValidationError(openai.APIResponseValidationError): # type: ig
|
|||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.message = "litellm.APIResponseValidationError: {}".format(message)
|
||||
self.message = _mask_message("litellm.APIResponseValidationError: {}".format(message))
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
|
@ -675,9 +730,9 @@ class JSONSchemaValidationError(APIResponseValidationError):
|
|||
self.raw_response = raw_response
|
||||
self.schema = schema
|
||||
self.model = model
|
||||
message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
|
||||
message = _mask_message("litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
|
||||
model, raw_response, schema
|
||||
)
|
||||
))
|
||||
self.message = message
|
||||
super().__init__(model=model, message=message, llm_provider=llm_provider)
|
||||
|
||||
|
@ -701,7 +756,7 @@ class UnsupportedParamsError(BadRequestError):
|
|||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = "litellm.UnsupportedParamsError: {}".format(message)
|
||||
self.message = _mask_message("litellm.UnsupportedParamsError: {}".format(message))
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -746,7 +801,7 @@ class BudgetExceededError(Exception):
|
|||
self.max_budget = max_budget
|
||||
message = (
|
||||
message
|
||||
or f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
|
||||
or _mask_message(f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}")
|
||||
)
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
@ -756,17 +811,17 @@ class BudgetExceededError(Exception):
|
|||
class InvalidRequestError(openai.BadRequestError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider):
|
||||
self.status_code = 400
|
||||
self.message = message
|
||||
self.message = _mask_message("litellm.InvalidRequestError: {}".format(message))
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.response = httpx.Response(
|
||||
status_code=400,
|
||||
response = httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(
|
||||
method="GET", url="https://litellm.ai"
|
||||
), # mock request object
|
||||
)
|
||||
super().__init__(
|
||||
message=self.message, response=self.response, body=None
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
|
@ -784,7 +839,7 @@ class MockException(openai.APIError):
|
|||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = "litellm.MockException: {}".format(message)
|
||||
self.message = _mask_message("litellm.MockException: {}".format(message))
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
@ -797,9 +852,9 @@ class MockException(openai.APIError):
|
|||
|
||||
class LiteLLMUnknownProvider(BadRequestError):
|
||||
def __init__(self, model: str, custom_llm_provider: Optional[str] = None):
|
||||
self.message = LiteLLMCommonStrings.llm_provider_not_provided.value.format(
|
||||
self.message = _mask_message(LiteLLMCommonStrings.llm_provider_not_provided.value.format(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
))
|
||||
super().__init__(
|
||||
self.message, model=model, llm_provider=custom_llm_provider, response=None
|
||||
)
|
||||
|
|
59
tests/test_exceptions_masking.py
Normal file
59
tests/test_exceptions_masking.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import unittest
|
||||
import litellm
|
||||
from litellm import exceptions
|
||||
|
||||
|
||||
class TestExceptionsMasking(unittest.TestCase):
|
||||
def test_api_key_masking_in_exceptions(self):
|
||||
"""Test that API keys are properly masked in exception messages"""
|
||||
|
||||
# Test with a message containing an API key
|
||||
api_key = "sk-12345678901234567890"
|
||||
message = f"Failed to authenticate with API key {api_key}"
|
||||
|
||||
# Create an exception with this message
|
||||
exception = exceptions.AuthenticationError(
|
||||
message=message,
|
||||
llm_provider="test_provider",
|
||||
model="test_model"
|
||||
)
|
||||
|
||||
# Check that the API key is not present in the exception message
|
||||
self.assertNotIn(api_key, exception.message)
|
||||
# Check that a masked version is present instead (should have the prefix and suffix)
|
||||
self.assertIn("sk-1", exception.message)
|
||||
self.assertIn("7890", exception.message)
|
||||
|
||||
def test_multiple_sensitive_keys_masked(self):
|
||||
"""Test that multiple sensitive keys in the same message are masked"""
|
||||
|
||||
# Message with multiple sensitive information
|
||||
message = (
|
||||
"Error occurred. API key: sk-abc123def456, "
|
||||
"Secret key: secret_xyz987, "
|
||||
"Password: pass123word"
|
||||
)
|
||||
|
||||
# Create an exception with this message
|
||||
exception = exceptions.BadRequestError(
|
||||
message=message,
|
||||
model="test_model",
|
||||
llm_provider="test_provider"
|
||||
)
|
||||
|
||||
# Check that none of the sensitive data is present
|
||||
self.assertNotIn("sk-abc123def456", exception.message)
|
||||
self.assertNotIn("secret_xyz987", exception.message)
|
||||
self.assertNotIn("pass123word", exception.message)
|
||||
|
||||
# Check that masked versions are present
|
||||
self.assertIn("sk-a", exception.message)
|
||||
self.assertIn("456", exception.message)
|
||||
self.assertIn("secr", exception.message)
|
||||
self.assertIn("987", exception.message)
|
||||
self.assertIn("pass", exception.message)
|
||||
self.assertIn("word", exception.message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Add table
Add a link
Reference in a new issue