From a2dc3cec95bdc73da63a87fe8544f27bb9fbad2d Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 29 Nov 2024 14:25:00 -0800
Subject: [PATCH] fix(http_handler.py): mask gemini api key in error logs

Fixes https://github.com/BerriAI/litellm/issues/6963
---
 litellm/llms/custom_httpx/http_handler.py | 73 +++++++++++++++++++++--
 tests/local_testing/test_utils.py         | 47 +++++++++++++++
 2 files changed, 116 insertions(+), 4 deletions(-)

diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index f5c4f694d..bbbfce0bf 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -28,6 +28,58 @@ headers = {
 _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
 _DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600  # 1 hour, re-use the same httpx client for 1 hour
 
+import re
+
+
+def mask_sensitive_info(error_message):
+    # Find the start of the key parameter
+    key_index = error_message.find("key=")
+
+    # If key is found
+    if key_index != -1:
+        # Find the end of the key parameter (next & or end of string)
+        next_param = error_message.find("&", key_index)
+
+        if next_param == -1:
+            # If no more parameters, mask until the end of the string
+            masked_message = error_message[: key_index + 4] + "[REDACTED_API_KEY]"
+        else:
+            # Replace the key with redacted value, keeping other parameters
+            masked_message = (
+                error_message[: key_index + 4]
+                + "[REDACTED_API_KEY]"
+                + error_message[next_param:]
+            )
+
+        return masked_message
+
+    return error_message
+
+
+class MaskedHTTPStatusError(httpx.HTTPStatusError):
+    def __init__(
+        self, original_error, message: Optional[str] = None, text: Optional[str] = None
+    ):
+        # Create a new error with the masked URL
+        masked_url = mask_sensitive_info(str(original_error.request.url))
+        # Create a new error that looks like the original, but with a masked URL
+
+        super().__init__(
+            message=original_error.message,
+            request=httpx.Request(
+                method=original_error.request.method,
+                url=masked_url,
+                headers=original_error.request.headers,
+                content=original_error.request.content,
+            ),
+            response=httpx.Response(
+                status_code=original_error.response.status_code,
+                content=original_error.response.content,
+            ),
+        )
+        self.message = message
+        self.text = text
+
 
 class AsyncHTTPHandler:
     def __init__(
@@ -155,13 +207,17 @@ class AsyncHTTPHandler:
                 headers=headers,
             )
         except httpx.HTTPStatusError as e:
-            setattr(e, "status_code", e.response.status_code)
+
             if stream is True:
                 setattr(e, "message", await e.response.aread())
                 setattr(e, "text", await e.response.aread())
             else:
                 setattr(e, "message", e.response.text)
                 setattr(e, "text", e.response.text)
+            e = MaskedHTTPStatusError(
+                e, message=getattr(e, "message", None), text=getattr(e, "text", None)
+            )
+            setattr(e, "status_code", e.response.status_code)
             raise e
         except Exception as e:
             raise e
@@ -399,11 +455,20 @@ class HTTPHandler:
                     llm_provider="litellm-httpx-handler",
                 )
         except httpx.HTTPStatusError as e:
-            setattr(e, "status_code", e.response.status_code)
+            error_text = mask_sensitive_info(e.response.text)
+
             if stream is True:
-                setattr(e, "message", e.response.read())
+                setattr(e, "message", mask_sensitive_info(e.response.read()))
+                setattr(e, "text", mask_sensitive_info(e.response.read()))
             else:
-                setattr(e, "message", e.response.text)
+                setattr(e, "message", error_text)
+                setattr(e, "text", error_text)
+
+            e = MaskedHTTPStatusError(
+                e, message=getattr(e, "message", None), text=getattr(e, "text", None)
+            )
+            setattr(e, "status_code", e.response.status_code)
+
             raise e
         except Exception as e:
             raise e
diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py
index 7c349a658..c35f6ba77 100644
--- a/tests/local_testing/test_utils.py
+++ b/tests/local_testing/test_utils.py
@@ -1032,3 +1032,50 @@ def test_get_end_user_id_for_cost_tracking(
         get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
         == expected_end_user_id
     )
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_sensitive_url_filtering(sync_mode):
+    """ensure gemini api key not leaked in logs - Relevant Issue: https://github.com/BerriAI/litellm/issues/6963"""
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
+    import json
+    import httpx
+
+    client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
+    gemini_api_key = os.getenv("GEMINI_API_KEY")
+    request_data = {
+        "input": [{"content": "hey, how's it going?"}],
+        "model": "text-embedding-004",
+        "max_tokens": 200,  # invalid param
+    }
+    with pytest.raises(httpx.HTTPStatusError) as e:
+        if sync_mode:
+            client.post(
+                url=f"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents?key={gemini_api_key}",
+                data=json.dumps(request_data),
+            )
+        else:
+            await client.post(
+                url=f"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents?key={gemini_api_key}",
+                data=json.dumps(request_data),
+            )
+    print(e.traceback)
+    print(f"exception received: {e._excinfo[1]}")
+    assert gemini_api_key not in str(e._excinfo[1])
+
+    with pytest.raises(litellm.BadRequestError) as e:
+        if sync_mode:
+            litellm.embedding(
+                model="gemini/text-embedding-004",
+                input="hey, how's it going?",
+                max_tokens=200,
+            )
+        else:
+            await litellm.embedding(
+                model="gemini/text-embedding-004",
+                input="hey, how's it going?",
+                max_tokens=200,
+            )
+
+    assert "invalid json payload" in str(e._excinfo[1]).lower()
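
Note for reviewers: a quick standalone sketch of how the new mask_sensitive_info
helper behaves. The key values below are made-up placeholders, and the sketch
assumes the helper is importable from litellm.llms.custom_httpx.http_handler as
added in this patch:

    from litellm.llms.custom_httpx.http_handler import mask_sensitive_info

    # Key in the middle of the query string: everything between "key=" and
    # the next "&" is replaced; the remaining parameters are preserved.
    url = "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents?key=FAKE-KEY-123&alt=json"
    assert mask_sensitive_info(url) == (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        "text-embedding-004:batchEmbedContents?key=[REDACTED_API_KEY]&alt=json"
    )

    # Key as the last parameter: masked through to the end of the string.
    assert (
        mask_sensitive_info("https://example.com/v1?key=FAKE-KEY-123")
        == "https://example.com/v1?key=[REDACTED_API_KEY]"
    )

    # No "key=" substring present: the message is returned unchanged.
    assert mask_sensitive_info("no secrets here") == "no secrets here"

Since the helper does a plain substring search for "key=", any parameter whose
name ends in "key" (e.g. "access_key=") is masked as well, which errs on the
side of redacting too much rather than too little.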