fix(http_handler.py): mask gemini api key in error logs

Fixes https://github.com/BerriAI/litellm/issues/6963
This commit is contained in:
Krrish Dholakia 2024-11-29 14:25:00 -08:00
parent 7624cc45e6
commit a2dc3cec95
2 changed files with 116 additions and 4 deletions

View file

@ -28,6 +28,58 @@ headers = {
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0) _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour _DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
import re
def mask_sensitive_info(error_message):
# Find the start of the key parameter
key_index = error_message.find("key=")
# If key is found
if key_index != -1:
# Find the end of the key parameter (next & or end of string)
next_param = error_message.find("&", key_index)
if next_param == -1:
# If no more parameters, mask until the end of the string
masked_message = error_message[: key_index + 4] + "[REDACTED_API_KEY]"
else:
# Replace the key with redacted value, keeping other parameters
masked_message = (
error_message[: key_index + 4]
+ "[REDACTED_API_KEY]"
+ error_message[next_param:]
)
return masked_message
return error_message
class MaskedHTTPStatusError(httpx.HTTPStatusError):
def __init__(
self, original_error, message: Optional[str] = None, text: Optional[str] = None
):
# Create a new error with the masked URL
masked_url = mask_sensitive_info(str(original_error.request.url))
# Create a new error that looks like the original, but with a masked URL
super().__init__(
message=original_error.message,
request=httpx.Request(
method=original_error.request.method,
url=masked_url,
headers=original_error.request.headers,
content=original_error.request.content,
),
response=httpx.Response(
status_code=original_error.response.status_code,
content=original_error.response.content,
),
)
self.message = message
self.text = text
class AsyncHTTPHandler: class AsyncHTTPHandler:
def __init__( def __init__(
@ -155,13 +207,17 @@ class AsyncHTTPHandler:
headers=headers, headers=headers,
) )
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
setattr(e, "status_code", e.response.status_code)
if stream is True: if stream is True:
setattr(e, "message", await e.response.aread()) setattr(e, "message", await e.response.aread())
setattr(e, "text", await e.response.aread()) setattr(e, "text", await e.response.aread())
else: else:
setattr(e, "message", e.response.text) setattr(e, "message", e.response.text)
setattr(e, "text", e.response.text) setattr(e, "text", e.response.text)
e = MaskedHTTPStatusError(
e, message=getattr(e, "message", None), text=getattr(e, "text", None)
)
setattr(e, "status_code", e.response.status_code)
raise e raise e
except Exception as e: except Exception as e:
raise e raise e
@ -399,11 +455,20 @@ class HTTPHandler:
llm_provider="litellm-httpx-handler", llm_provider="litellm-httpx-handler",
) )
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
setattr(e, "status_code", e.response.status_code) error_text = mask_sensitive_info(e.response.text)
if stream is True: if stream is True:
setattr(e, "message", e.response.read()) setattr(e, "message", mask_sensitive_info(e.response.read()))
setattr(e, "text", mask_sensitive_info(e.response.read()))
else: else:
setattr(e, "message", e.response.text) setattr(e, "message", error_text)
setattr(e, "text", error_text)
e = MaskedHTTPStatusError(
e, message=getattr(e, "message", None), text=getattr(e, "text", None)
)
setattr(e, "status_code", e.response.status_code)
raise e raise e
except Exception as e: except Exception as e:
raise e raise e

View file

@ -1032,3 +1032,50 @@ def test_get_end_user_id_for_cost_tracking(
get_end_user_id_for_cost_tracking(litellm_params=litellm_params) get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
== expected_end_user_id == expected_end_user_id
) )
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_sensitive_url_filtering(sync_mode):
"""ensure gemini api key not leaked in logs - Relevant Issue: https://github.com/BerriAI/litellm/issues/6963"""
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
import json
import httpx
client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
gemini_api_key = os.getenv("GEMINI_API_KEY")
request_data = {
"input": [{"content": "hey, how's it going?"}],
"model": "text-embedding-004",
"max_tokens": 200, # invalid param
}
with pytest.raises(httpx.HTTPStatusError) as e:
if sync_mode:
client.post(
url=f"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents?key={gemini_api_key}",
data=json.dumps(request_data),
)
else:
await client.post(
url=f"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents?key={gemini_api_key}",
data=json.dumps(request_data),
)
print(e.traceback)
print(f"exception received: {e._excinfo[1]}")
assert gemini_api_key not in str(e._excinfo[1])
with pytest.raises(litellm.BadRequestError) as e:
if sync_mode:
litellm.embedding(
model="gemini/text-embedding-004",
input="hey, how's it going?",
max_tokens=200,
)
else:
await litellm.embedding(
model="gemini/text-embedding-004",
input="hey, how's it going?",
max_tokens=200,
)
assert "invalid json payload" in str(e._excinfo[1]).lower()