forked from phoenix/litellm-mirror
fix(http_handler.py): mask gemini api key in error logs
Fixes https://github.com/BerriAI/litellm/issues/6963
This commit is contained in:
parent
7624cc45e6
commit
a2dc3cec95
2 changed files with 116 additions and 4 deletions
|
@ -28,6 +28,58 @@ headers = {
|
|||
# Default timeout for outbound httpx calls: 5s overall budget, 5s to connect.
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)

# How long a cached httpx client is re-used before being rebuilt.
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600  # 1 hour, re-use the same httpx client for 1 hour


import re
|
||||
|
||||
|
||||
def mask_sensitive_info(error_message):
    """Redact API keys passed as ``key=<value>`` query parameters.

    Gemini (and other Google APIs) carry the API key in a ``key=`` query
    parameter, so raw request URLs and echoed error bodies leak credentials
    into logs (https://github.com/BerriAI/litellm/issues/6963).

    Args:
        error_message: text that may contain ``key=...`` parameters. Bytes
            input (e.g. ``httpx.Response.read()``) is decoded first; other
            non-str input is coerced with ``str()``.

    Returns:
        The input with every ``key=<value>`` span replaced by
        ``key=[REDACTED_API_KEY]``; text without ``key=`` is unchanged.
    """
    # Defensive: callers in the streaming path pass raw response bytes, and
    # str-pattern searches on bytes raise TypeError — which would shadow the
    # real HTTP error inside an except handler.
    if isinstance(error_message, bytes):
        error_message = error_message.decode("utf-8", errors="replace")
    elif not isinstance(error_message, str):
        error_message = str(error_message)
    # Mask from "key=" up to the next "&" (or end of string), matching the
    # previous find()-based logic — but re.sub redacts *every* occurrence,
    # not just the first one.
    return re.sub(r"key=[^&]*", "key=[REDACTED_API_KEY]", error_message)
|
||||
|
||||
|
||||
class MaskedHTTPStatusError(httpx.HTTPStatusError):
    """``httpx.HTTPStatusError`` clone with API keys scrubbed out.

    Rebuilds the original error's request/response pair with ``key=<value>``
    query parameters redacted so credentials never reach logs or tracebacks
    (https://github.com/BerriAI/litellm/issues/6963).
    """

    def __init__(
        self, original_error, message: Optional[str] = None, text: Optional[str] = None
    ):
        # Scrub the key from the request URL before re-raising.
        masked_url = mask_sensitive_info(str(original_error.request.url))
        # Also scrub the exception message itself: the async caller sets
        # ``message`` from the raw response text, and provider error bodies
        # can echo the full request URL (including the key) back to us.
        masked_message = mask_sensitive_info(str(original_error.message))
        super().__init__(
            message=masked_message,
            request=httpx.Request(
                method=original_error.request.method,
                url=masked_url,
                headers=original_error.request.headers,
                content=original_error.request.content,
            ),
            response=httpx.Response(
                status_code=original_error.response.status_code,
                content=original_error.response.content,
            ),
        )
        # Mask the caller-supplied copies too; non-str payloads (e.g. bytes
        # from a streamed body) are stored as-is rather than crashing inside
        # the exception handler that constructs us.
        self.message = mask_sensitive_info(message) if isinstance(message, str) else message
        self.text = mask_sensitive_info(text) if isinstance(text, str) else text
|
||||
|
||||
|
||||
class AsyncHTTPHandler:
|
||||
def __init__(
|
||||
|
@ -155,13 +207,17 @@ class AsyncHTTPHandler:
|
|||
headers=headers,
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
|
||||
if stream is True:
|
||||
setattr(e, "message", await e.response.aread())
|
||||
setattr(e, "text", await e.response.aread())
|
||||
else:
|
||||
setattr(e, "message", e.response.text)
|
||||
setattr(e, "text", e.response.text)
|
||||
e = MaskedHTTPStatusError(
|
||||
e, message=getattr(e, "message", None), text=getattr(e, "text", None)
|
||||
)
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
@ -399,11 +455,20 @@ class HTTPHandler:
|
|||
llm_provider="litellm-httpx-handler",
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
error_text = mask_sensitive_info(e.response.text)
|
||||
|
||||
if stream is True:
|
||||
setattr(e, "message", e.response.read())
|
||||
setattr(e, "message", mask_sensitive_info(e.response.read()))
|
||||
setattr(e, "text", mask_sensitive_info(e.response.read()))
|
||||
else:
|
||||
setattr(e, "message", e.response.text)
|
||||
setattr(e, "message", error_text)
|
||||
setattr(e, "text", error_text)
|
||||
|
||||
e = MaskedHTTPStatusError(
|
||||
e, message=getattr(e, "message", None), text=getattr(e, "text", None)
|
||||
)
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
|
@ -1032,3 +1032,50 @@ def test_get_end_user_id_for_cost_tracking(
|
|||
get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
|
||||
== expected_end_user_id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_sensitive_url_filtering(sync_mode):
    """Ensure the Gemini API key is not leaked in logs/exceptions.

    Relevant issue: https://github.com/BerriAI/litellm/issues/6963
    """
    from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
    import json
    import httpx

    client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    url = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        f"text-embedding-004:batchEmbedContents?key={gemini_api_key}"
    )
    request_data = {
        "input": [{"content": "hey, how's it going?"}],
        "model": "text-embedding-004",
        "max_tokens": 200,  # invalid param -> provider returns a 4xx error
    }

    # Raw handler call: the raised HTTPStatusError must not contain the key.
    with pytest.raises(httpx.HTTPStatusError) as e:
        if sync_mode:
            client.post(url=url, data=json.dumps(request_data))
        else:
            await client.post(url=url, data=json.dumps(request_data))
    print(e.traceback)
    # e.value is pytest's public ExceptionInfo accessor (was the private
    # e._excinfo[1], which pytest does not guarantee to keep).
    print(f"exception received: {e.value}")
    assert gemini_api_key not in str(e.value)

    # Same check through the public embedding API.
    with pytest.raises(litellm.BadRequestError) as e:
        if sync_mode:
            litellm.embedding(
                model="gemini/text-embedding-004",
                input="hey, how's it going?",
                max_tokens=200,
            )
        else:
            # fix: litellm.embedding() is synchronous; the async branch must
            # call the awaitable litellm.aembedding() to actually exercise
            # the async code path.
            await litellm.aembedding(
                model="gemini/text-embedding-004",
                input="hey, how's it going?",
                max_tokens=200,
            )

    assert "invalid json payload" in str(e.value).lower()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue