fix(test_custom_callbacks_input.py): unit tests for 'turn_off_message_logging'

ensure no raw request is logged either
Krrish Dholakia 2024-06-07 15:39:15 -07:00
parent 51fb199329
commit f73b6033fd
4 changed files with 72 additions and 9 deletions
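
For context on what the diff below tests: with litellm.turn_off_message_logging set to True, the raw_request entry that LiteLLM stores in litellm_params["metadata"] should carry the redaction string rather than the generated curl command. The following is a minimal, hypothetical sketch of observing that from a custom callback; the RawRequestProbe class name and the use of log_success_event as the observation point are assumptions, not part of this commit.

import time

import litellm
from litellm.integrations.custom_logger import CustomLogger


class RawRequestProbe(CustomLogger):
    # Hypothetical probe: print the raw_request metadata the success callback receives.
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        metadata = (kwargs.get("litellm_params") or {}).get("metadata") or {}
        # With turn_off_message_logging=True this is expected to be the redaction
        # string, not the actual curl command for the request.
        print("raw_request:", metadata.get("raw_request"))


litellm.turn_off_message_logging = True
litellm.callbacks = [RawRequestProbe()]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    mock_response="Going well!",  # mocked response, so no real API call is made
)
time.sleep(2)  # callbacks can run on a background thread; give them a moment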


@@ -471,10 +471,14 @@ def mock_completion(
         try:
             _, custom_llm_provider, _, _ = litellm.utils.get_llm_provider(model=model)
             model_response._hidden_params["custom_llm_provider"] = custom_llm_provider
-        except:
+        except Exception:
             # dont let setting a hidden param block a mock_respose
             pass
+        logging.post_call(
+            input=messages,
+            api_key="my-secret-key",
+            original_response="my-original-response",
+        )
         return model_response
     except Exception as e:


@@ -10,6 +10,7 @@ from typing import Optional, Literal, List, Union
 from litellm import completion, embedding, Cache
 import litellm
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.utils import LiteLLMCommonStrings
 
 # Test Scenarios (test across completion, streaming, embedding)
 ## 1: Pre-API-Call
@@ -67,7 +68,18 @@ class CompletionCustomHandler(
             assert isinstance(kwargs["start_time"], (datetime, type(None)))
             assert isinstance(kwargs["stream"], bool)
             assert isinstance(kwargs["user"], (str, type(None)))
-        except Exception as e:
+            ### METADATA
+            metadata_value = kwargs["litellm_params"].get("metadata")
+            assert metadata_value is None or isinstance(metadata_value, dict)
+            if metadata_value is not None:
+                if litellm.turn_off_message_logging is True:
+                    assert (
+                        metadata_value["raw_request"]
+                        is LiteLLMCommonStrings.redacted_by_litellm.value
+                    )
+                else:
+                    assert isinstance(metadata_value["raw_request"], str)
+        except Exception:
             print(f"Assertion Error: {traceback.format_exc()}")
             self.errors.append(traceback.format_exc())
@@ -177,6 +189,8 @@ class CompletionCustomHandler(
             assert isinstance(
                 kwargs["original_response"],
                 (str, litellm.CustomStreamWrapper, BaseModel),
+            ), "Original Response={}. Allowed types=[str, litellm.CustomStreamWrapper, BaseModel]".format(
+                kwargs["original_response"]
             )
             assert isinstance(kwargs["additional_args"], (dict, type(None)))
             assert isinstance(kwargs["log_event_type"], str)
@@ -1053,3 +1067,25 @@ def test_image_generation_openai():
 ## Test Azure + Sync
 ## Test Azure + Async
 
+
+##### PII REDACTION ######
+def test_turn_off_message_logging():
+    """
+    If 'turn_off_message_logging' is true, assert no user request information is logged.
+    """
+    litellm.turn_off_message_logging = True
+
+    # sync completion
+    customHandler = CompletionCustomHandler()
+    litellm.callbacks = [customHandler]
+
+    _ = litellm.completion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        mock_response="Going well!",
+    )
+
+    time.sleep(2)
+
+    assert len(customHandler.errors) == 0
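
One note on the new test: it sets the module-level litellm.turn_off_message_logging flag and leaves it set for the rest of the test session. If that state ever needs to be isolated, a small pytest fixture along these lines could do it; the fixture name is hypothetical and not part of this commit.

import litellm
import pytest


@pytest.fixture
def message_logging_disabled():
    # Hypothetical helper: remember the current flag value, disable message
    # logging for the test, then restore the original value afterwards.
    previous = litellm.turn_off_message_logging
    litellm.turn_off_message_logging = True
    yield
    litellm.turn_off_message_logging = previous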


@@ -1,5 +1,10 @@
 from typing import List, Optional, Union, Dict, Tuple, Literal
 from typing_extensions import TypedDict
+from enum import Enum
+
+
+class LiteLLMCommonStrings(Enum):
+    redacted_by_litellm = "redacted by litellm. 'litellm.turn_off_message_logging=True'"
 
 
 class CostPerToken(TypedDict):


@@ -1308,14 +1308,28 @@ class Logging:
                 )
             else:
                 verbose_logger.debug(f"\033[92m{curl_command}\033[0m\n")
-            # check if user wants the raw request logged to their logging provider (like LangFuse)
+            # log raw request to provider (like LangFuse)
             try:
                 # [Non-blocking Extra Debug Information in metadata]
                 _litellm_params = self.model_call_details.get("litellm_params", {})
                 _metadata = _litellm_params.get("metadata", {}) or {}
-                _metadata["raw_request"] = str(curl_command)
-            except:
-                pass
+                if (
+                    litellm.turn_off_message_logging is not None
+                    and litellm.turn_off_message_logging is True
+                ):
+                    _metadata["raw_request"] = (
+                        "redacted by litellm. \
+                        'litellm.turn_off_message_logging=True'"
+                    )
+                else:
+                    _metadata["raw_request"] = str(curl_command)
+            except Exception as e:
+                _metadata["raw_request"] = (
+                    "Unable to Log \
+                    raw request: {}".format(
+                        str(e)
+                    )
+                )
             if self.logger_fn and callable(self.logger_fn):
                 try:
                     self.logger_fn(
@@ -2684,7 +2698,9 @@ class Logging:
         # check if user opted out of logging message/response to callbacks
         if litellm.turn_off_message_logging == True:
             # remove messages, prompts, input, response from logging
-            self.model_call_details["messages"] = "redacted-by-litellm"
+            self.model_call_details["messages"] = [
+                {"role": "user", "content": "redacted-by-litellm"}
+            ]
             self.model_call_details["prompt"] = ""
             self.model_call_details["input"] = ""
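
The hunk just above changes the redacted messages value from a bare string to a list containing a single redacted user message, so the field keeps the usual chat-payload shape for callbacks that iterate over role/content dicts. A small, hypothetical illustration of why the shape matters (the consumer code below is not from this commit):

# Hypothetical downstream consumer that expects OpenAI-style message dicts.
def first_user_content(messages):
    return next(m["content"] for m in messages if m["role"] == "user")


redacted_old = "redacted-by-litellm"  # old redaction: a bare string
redacted_new = [{"role": "user", "content": "redacted-by-litellm"}]  # new redaction

print(first_user_content(redacted_new))  # prints "redacted-by-litellm"
# first_user_content(redacted_old) would fail, since iterating a str yields characters.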
@@ -4064,7 +4080,9 @@ def openai_token_counter(
                 for c in value:
                     if c["type"] == "text":
                         text += c["text"]
-                        num_tokens += len(encoding.encode(c["text"], disallowed_special=()))
+                        num_tokens += len(
+                            encoding.encode(c["text"], disallowed_special=())
+                        )
                     elif c["type"] == "image_url":
                         if isinstance(c["image_url"], dict):
                             image_url_dict = c["image_url"]