forked from phoenix/litellm-mirror
fix(http_handler.py): mask gemini api key in error logs
Fixes https://github.com/BerriAI/litellm/issues/6963
parent 7624cc45e6
commit a2dc3cec95
2 changed files with 116 additions and 4 deletions

@@ -28,6 +28,58 @@ headers = {
 _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
 _DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600  # 1 hour, re-use the same httpx client for 1 hour
+
+import re
+
+
+def mask_sensitive_info(error_message):
+    # Find the start of the key parameter
+    key_index = error_message.find("key=")
+
+    # If key is found
+    if key_index != -1:
+        # Find the end of the key parameter (next & or end of string)
+        next_param = error_message.find("&", key_index)
+
+        if next_param == -1:
+            # If no more parameters, mask until the end of the string
+            masked_message = error_message[: key_index + 4] + "[REDACTED_API_KEY]"
+        else:
+            # Replace the key with redacted value, keeping other parameters
+            masked_message = (
+                error_message[: key_index + 4]
+                + "[REDACTED_API_KEY]"
+                + error_message[next_param:]
+            )
+
+        return masked_message
+
+    return error_message
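
For reference, here is how the new helper behaves on its three branches: key at the end of the query string, key followed by another parameter, and no key at all. A minimal sketch, assuming mask_sensitive_info is importable from litellm.llms.custom_httpx.http_handler as this hunk defines it; the URLs are made up:

    from litellm.llms.custom_httpx.http_handler import mask_sensitive_info

    # key= is the last parameter: everything after it is masked
    print(mask_sensitive_info("https://example.com/v1?key=abc123"))
    # -> https://example.com/v1?key=[REDACTED_API_KEY]

    # key= is followed by another parameter: the tail is preserved
    print(mask_sensitive_info("https://example.com/v1?key=abc123&alt=json"))
    # -> https://example.com/v1?key=[REDACTED_API_KEY]&alt=json

    # no key= at all: the message passes through unchanged
    print(mask_sensitive_info("https://example.com/v1?alt=json"))
    # -> https://example.com/v1?alt=json
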
+
+
+class MaskedHTTPStatusError(httpx.HTTPStatusError):
+    def __init__(
+        self, original_error, message: Optional[str] = None, text: Optional[str] = None
+    ):
+        # Create a new error with the masked URL
+        masked_url = mask_sensitive_info(str(original_error.request.url))
+        # Create a new error that looks like the original, but with a masked URL
+
+        super().__init__(
+            message=original_error.message,
+            request=httpx.Request(
+                method=original_error.request.method,
+                url=masked_url,
+                headers=original_error.request.headers,
+                content=original_error.request.content,
+            ),
+            response=httpx.Response(
+                status_code=original_error.response.status_code,
+                content=original_error.response.content,
+            ),
+        )
+        self.message = message
+        self.text = text
+
+
 class AsyncHTTPHandler:
     def __init__(
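
A similar sketch for the wrapper class, building a synthetic httpx error by hand. One assumption to flag: httpx.HTTPStatusError has no message attribute of its own, so, like the handler code in the next hunk, the sketch setattrs one before wrapping; the URL and key are invented:

    import httpx
    from litellm.llms.custom_httpx.http_handler import MaskedHTTPStatusError

    req = httpx.Request(
        "POST", "https://example.com/v1/models/foo:embed?key=abc123", content=b"{}"
    )
    resp = httpx.Response(status_code=400, content=b"bad request", request=req)

    original = httpx.HTTPStatusError("400 Bad Request", request=req, response=resp)
    original.message = resp.text  # the handlers set this via setattr before wrapping

    masked = MaskedHTTPStatusError(original, message=original.message, text=resp.text)
    assert "abc123" not in str(masked.request.url)
    print(masked.request.url)  # .../foo:embed?key=[REDACTED_API_KEY]
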
@@ -155,13 +207,17 @@ class AsyncHTTPHandler:
                 headers=headers,
             )
         except httpx.HTTPStatusError as e:
-            setattr(e, "status_code", e.response.status_code)
             if stream is True:
                 setattr(e, "message", await e.response.aread())
                 setattr(e, "text", await e.response.aread())
             else:
                 setattr(e, "message", e.response.text)
                 setattr(e, "text", e.response.text)
+
+            e = MaskedHTTPStatusError(
+                e, message=getattr(e, "message", None), text=getattr(e, "text", None)
+            )
+            setattr(e, "status_code", e.response.status_code)
             raise e
         except Exception as e:
             raise e
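
A detail worth a second look in this hunk: the setattr(e, "status_code", ...) call moves below the wrap because rebinding e to a fresh MaskedHTTPStatusError discards attributes that were set on the original exception object. A tiny illustration with stand-in exception types:

    # Attributes live on the instance, so a replacement exception starts clean.
    original = ValueError("boom")
    original.status_code = 400

    wrapped = RuntimeError("wrapped")  # stands in for MaskedHTTPStatusError(original, ...)
    assert not hasattr(wrapped, "status_code")  # must be re-applied to the new object
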
@@ -399,11 +455,20 @@ class HTTPHandler:
                 llm_provider="litellm-httpx-handler",
             )
         except httpx.HTTPStatusError as e:
-            setattr(e, "status_code", e.response.status_code)
+            error_text = mask_sensitive_info(e.response.text)
+
             if stream is True:
-                setattr(e, "message", e.response.read())
+                setattr(e, "message", mask_sensitive_info(e.response.read()))
+                setattr(e, "text", mask_sensitive_info(e.response.read()))
             else:
-                setattr(e, "message", e.response.text)
+                setattr(e, "message", error_text)
+                setattr(e, "text", error_text)
+
+            e = MaskedHTTPStatusError(
+                e, message=getattr(e, "message", None), text=getattr(e, "text", None)
+            )
+            setattr(e, "status_code", e.response.status_code)
+
             raise e
         except Exception as e:
             raise e
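
One caveat on the streaming branch above: httpx.Response.read() returns bytes, while mask_sensitive_info searches with str.find, so that path would need a decode step along these lines (masked_stream_text is a hypothetical helper name, not part of this commit):

    import httpx
    from litellm.llms.custom_httpx.http_handler import mask_sensitive_info

    def masked_stream_text(response: httpx.Response) -> str:
        # hypothetical helper: decode first, since read() returns bytes and
        # mask_sensitive_info matches on str
        return mask_sensitive_info(response.read().decode("utf-8", errors="replace"))
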
@@ -1032,3 +1032,50 @@ def test_get_end_user_id_for_cost_tracking(
         get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
         == expected_end_user_id
     )
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_sensitive_url_filtering(sync_mode):
+    """ensure gemini api key not leaked in logs - Relevant Issue: https://github.com/BerriAI/litellm/issues/6963"""
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
+    import json
+    import httpx
+
+    client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
+    gemini_api_key = os.getenv("GEMINI_API_KEY")
+    request_data = {
+        "input": [{"content": "hey, how's it going?"}],
+        "model": "text-embedding-004",
+        "max_tokens": 200,  # invalid param
+    }
+    with pytest.raises(httpx.HTTPStatusError) as e:
+        if sync_mode:
+            client.post(
+                url=f"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents?key={gemini_api_key}",
+                data=json.dumps(request_data),
+            )
+        else:
+            await client.post(
+                url=f"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents?key={gemini_api_key}",
+                data=json.dumps(request_data),
+            )
+    print(e.traceback)
+    print(f"exception received: {e._excinfo[1]}")
+    assert gemini_api_key not in str(e._excinfo[1])
+
+    with pytest.raises(litellm.BadRequestError) as e:
+        if sync_mode:
+            litellm.embedding(
+                model="gemini/text-embedding-004",
+                input="hey, how's it going?",
+                max_tokens=200,
+            )
+        else:
+            await litellm.embedding(
+                model="gemini/text-embedding-004",
+                input="hey, how's it going?",
+                max_tokens=200,
+            )
+
+    assert "invalid json payload" in str(e._excinfo[1]).lower()