should_run_prompt_management_hooks

Ishaan Jaff 2025-04-14 16:34:08 -07:00
parent d986b5d6b1
commit e64254b381
2 changed files with 167 additions and 149 deletions


@@ -249,9 +249,9 @@ class Logging(LiteLLMLoggingBaseClass):
         self.litellm_trace_id = litellm_trace_id
         self.function_id = function_id
         self.streaming_chunks: List[Any] = []  # for generating complete stream response
-        self.sync_streaming_chunks: List[
-            Any
-        ] = []  # for generating complete stream response
+        self.sync_streaming_chunks: List[Any] = (
+            []
+        )  # for generating complete stream response
         self.log_raw_request_response = log_raw_request_response

         # Initialize dynamic callbacks
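Note: apart from the new `should_run_prompt_management_hooks` helper introduced further down, nearly every hunk in this commit is a mechanical re-wrap like the one above: an assignment whose subscript key used to be split across lines is rewritten so the key stays on one line and the long right-hand side is wrapped in parentheses. A minimal before/after sketch of the pattern (the dict value here is made up for illustration):

```python
# Illustrative only: both layouts assign the same value to the same key.
model_call_details: dict = {}

# Old layout: the subscript key is broken across lines to stay under the line limit.
model_call_details[
    "response_cost_failure_debug_information"
] = {"cost": 0.0}

# New layout: the key stays inline and the right-hand side is wrapped in parentheses.
model_call_details["response_cost_failure_debug_information"] = (
    {"cost": 0.0}
)
```

Both forms are equivalent at runtime; the new layout is what more recent Black/Ruff formatter versions tend to emit, which is why so many lines change without any behavioral difference.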
@@ -455,6 +455,20 @@ class Logging(LiteLLMLoggingBaseClass):
         if "custom_llm_provider" in self.model_call_details:
             self.custom_llm_provider = self.model_call_details["custom_llm_provider"]

+    def should_run_prompt_management_hooks(
+        self,
+        prompt_id: str,
+        kwargs: Dict,
+    ) -> bool:
+        """
+        Return True if prompt management hooks should be run
+        """
+        if prompt_id:
+            return True
+        if kwargs.get("inject_cache_control_breakpoints_locations", None):
+            return True
+        return False
+
     def get_chat_completion_prompt(
         self,
         model: str,
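The new guard is small enough to exercise on its own. Below is a standalone sketch of the same predicate (outside the `Logging` class) with a few illustrative calls; the shape of the breakpoint-location value is a placeholder, since this diff only shows the kwargs key being checked:

```python
from typing import Any, Dict, Optional


def should_run_prompt_management_hooks(
    prompt_id: Optional[str], kwargs: Dict[str, Any]
) -> bool:
    """Standalone copy of the predicate added in this commit: hooks run when a
    prompt_id is supplied, or when the caller asked for cache-control
    breakpoint injection via kwargs."""
    if prompt_id:
        return True
    if kwargs.get("inject_cache_control_breakpoints_locations", None):
        return True
    return False


# Hooks run when a prompt_id is given ...
assert should_run_prompt_management_hooks("my-prompt-id", {}) is True
# ... or when breakpoint-injection locations are passed (placeholder payload).
assert should_run_prompt_management_hooks(
    None, {"inject_cache_control_breakpoints_locations": [{"role": "system"}]}
) is True
# Otherwise the hooks are skipped.
assert should_run_prompt_management_hooks(None, {}) is False
```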
@@ -557,9 +571,9 @@ class Logging(LiteLLMLoggingBaseClass):
             model
         ):  # if model name was changes pre-call, overwrite the initial model call name with the new one
             self.model_call_details["model"] = model
-        self.model_call_details["litellm_params"][
-            "api_base"
-        ] = self._get_masked_api_base(additional_args.get("api_base", ""))
+        self.model_call_details["litellm_params"]["api_base"] = (
+            self._get_masked_api_base(additional_args.get("api_base", ""))
+        )

     def pre_call(self, input, api_key, model=None, additional_args={}):  # noqa: PLR0915
         # Log the exact input to the LLM API
@@ -588,10 +602,10 @@ class Logging(LiteLLMLoggingBaseClass):
             try:
                 # [Non-blocking Extra Debug Information in metadata]
                 if turn_off_message_logging is True:
-                    _metadata[
-                        "raw_request"
-                    ] = "redacted by litellm. \
-                        'litellm.turn_off_message_logging=True'"
+                    _metadata["raw_request"] = (
+                        "redacted by litellm. \
+                        'litellm.turn_off_message_logging=True'"
+                    )
                 else:
                     curl_command = self._get_request_curl_command(
                         api_base=additional_args.get("api_base", ""),
@@ -602,32 +616,32 @@ class Logging(LiteLLMLoggingBaseClass):
                     _metadata["raw_request"] = str(curl_command)

                 # split up, so it's easier to parse in the UI
-                self.model_call_details[
-                    "raw_request_typed_dict"
-                ] = RawRequestTypedDict(
-                    raw_request_api_base=str(
-                        additional_args.get("api_base") or ""
-                    ),
-                    raw_request_body=self._get_raw_request_body(
-                        additional_args.get("complete_input_dict", {})
-                    ),
-                    raw_request_headers=self._get_masked_headers(
-                        additional_args.get("headers", {}) or {},
-                        ignore_sensitive_headers=True,
-                    ),
-                    error=None,
-                )
+                self.model_call_details["raw_request_typed_dict"] = (
+                    RawRequestTypedDict(
+                        raw_request_api_base=str(
+                            additional_args.get("api_base") or ""
+                        ),
+                        raw_request_body=self._get_raw_request_body(
+                            additional_args.get("complete_input_dict", {})
+                        ),
+                        raw_request_headers=self._get_masked_headers(
+                            additional_args.get("headers", {}) or {},
+                            ignore_sensitive_headers=True,
+                        ),
+                        error=None,
+                    )
+                )
             except Exception as e:
-                self.model_call_details[
-                    "raw_request_typed_dict"
-                ] = RawRequestTypedDict(
-                    error=str(e),
-                )
-                _metadata[
-                    "raw_request"
-                ] = "Unable to Log \
-                raw request: {}".format(
-                    str(e)
-                )
+                self.model_call_details["raw_request_typed_dict"] = (
+                    RawRequestTypedDict(
+                        error=str(e),
+                    )
+                )
+                _metadata["raw_request"] = (
+                    "Unable to Log \
+                raw request: {}".format(
+                        str(e)
+                    )
+                )
         if self.logger_fn and callable(self.logger_fn):
             try:
@@ -957,9 +971,9 @@ class Logging(LiteLLMLoggingBaseClass):
                 verbose_logger.debug(
                     f"response_cost_failure_debug_information: {debug_info}"
                 )
-                self.model_call_details[
-                    "response_cost_failure_debug_information"
-                ] = debug_info
+                self.model_call_details["response_cost_failure_debug_information"] = (
+                    debug_info
+                )
             return None

         try:
@@ -984,9 +998,9 @@ class Logging(LiteLLMLoggingBaseClass):
             verbose_logger.debug(
                 f"response_cost_failure_debug_information: {debug_info}"
             )
-            self.model_call_details[
-                "response_cost_failure_debug_information"
-            ] = debug_info
+            self.model_call_details["response_cost_failure_debug_information"] = (
+                debug_info
+            )

         return None
@@ -1046,9 +1060,9 @@ class Logging(LiteLLMLoggingBaseClass):
         end_time = datetime.datetime.now()
         if self.completion_start_time is None:
             self.completion_start_time = end_time
-            self.model_call_details[
-                "completion_start_time"
-            ] = self.completion_start_time
+            self.model_call_details["completion_start_time"] = (
+                self.completion_start_time
+            )
         self.model_call_details["log_event_type"] = "successful_api_call"
         self.model_call_details["end_time"] = end_time
         self.model_call_details["cache_hit"] = cache_hit
@@ -1127,39 +1141,39 @@ class Logging(LiteLLMLoggingBaseClass):
                         "response_cost"
                     ]
                 else:
-                    self.model_call_details[
-                        "response_cost"
-                    ] = self._response_cost_calculator(result=logging_result)
+                    self.model_call_details["response_cost"] = (
+                        self._response_cost_calculator(result=logging_result)
+                    )

                 ## STANDARDIZED LOGGING PAYLOAD
-                self.model_call_details[
-                    "standard_logging_object"
-                ] = get_standard_logging_object_payload(
-                    kwargs=self.model_call_details,
-                    init_response_obj=logging_result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    logging_obj=self,
-                    status="success",
-                    standard_built_in_tools_params=self.standard_built_in_tools_params,
-                )
+                self.model_call_details["standard_logging_object"] = (
+                    get_standard_logging_object_payload(
+                        kwargs=self.model_call_details,
+                        init_response_obj=logging_result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        logging_obj=self,
+                        status="success",
+                        standard_built_in_tools_params=self.standard_built_in_tools_params,
+                    )
+                )
             elif isinstance(result, dict) or isinstance(result, list):
                 ## STANDARDIZED LOGGING PAYLOAD
-                self.model_call_details[
-                    "standard_logging_object"
-                ] = get_standard_logging_object_payload(
-                    kwargs=self.model_call_details,
-                    init_response_obj=result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    logging_obj=self,
-                    status="success",
-                    standard_built_in_tools_params=self.standard_built_in_tools_params,
-                )
+                self.model_call_details["standard_logging_object"] = (
+                    get_standard_logging_object_payload(
+                        kwargs=self.model_call_details,
+                        init_response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        logging_obj=self,
+                        status="success",
+                        standard_built_in_tools_params=self.standard_built_in_tools_params,
+                    )
+                )
             elif standard_logging_object is not None:
-                self.model_call_details[
-                    "standard_logging_object"
-                ] = standard_logging_object
+                self.model_call_details["standard_logging_object"] = (
+                    standard_logging_object
+                )
             else:  # streaming chunks + image gen.
                 self.model_call_details["response_cost"] = None
@@ -1215,23 +1229,23 @@ class Logging(LiteLLMLoggingBaseClass):
                     verbose_logger.debug(
                         "Logging Details LiteLLM-Success Call streaming complete"
                     )
-                    self.model_call_details[
-                        "complete_streaming_response"
-                    ] = complete_streaming_response
-                    self.model_call_details[
-                        "response_cost"
-                    ] = self._response_cost_calculator(result=complete_streaming_response)
+                    self.model_call_details["complete_streaming_response"] = (
+                        complete_streaming_response
+                    )
+                    self.model_call_details["response_cost"] = (
+                        self._response_cost_calculator(result=complete_streaming_response)
+                    )
                     ## STANDARDIZED LOGGING PAYLOAD
-                    self.model_call_details[
-                        "standard_logging_object"
-                    ] = get_standard_logging_object_payload(
-                        kwargs=self.model_call_details,
-                        init_response_obj=complete_streaming_response,
-                        start_time=start_time,
-                        end_time=end_time,
-                        logging_obj=self,
-                        status="success",
-                        standard_built_in_tools_params=self.standard_built_in_tools_params,
-                    )
+                    self.model_call_details["standard_logging_object"] = (
+                        get_standard_logging_object_payload(
+                            kwargs=self.model_call_details,
+                            init_response_obj=complete_streaming_response,
+                            start_time=start_time,
+                            end_time=end_time,
+                            logging_obj=self,
+                            status="success",
+                            standard_built_in_tools_params=self.standard_built_in_tools_params,
+                        )
+                    )

             callbacks = self.get_combined_callback_list(
                 dynamic_success_callbacks=self.dynamic_success_callbacks,
@@ -1580,10 +1594,10 @@ class Logging(LiteLLMLoggingBaseClass):
                         )
                     else:
                         if self.stream and complete_streaming_response:
-                            self.model_call_details[
-                                "complete_response"
-                            ] = self.model_call_details.get(
-                                "complete_streaming_response", {}
-                            )
+                            self.model_call_details["complete_response"] = (
+                                self.model_call_details.get(
+                                    "complete_streaming_response", {}
+                                )
+                            )
                             result = self.model_call_details["complete_response"]
                         openMeterLogger.log_success_event(
@@ -1623,10 +1637,10 @@ class Logging(LiteLLMLoggingBaseClass):
                         )
                     else:
                         if self.stream and complete_streaming_response:
-                            self.model_call_details[
-                                "complete_response"
-                            ] = self.model_call_details.get(
-                                "complete_streaming_response", {}
-                            )
+                            self.model_call_details["complete_response"] = (
+                                self.model_call_details.get(
+                                    "complete_streaming_response", {}
+                                )
+                            )
                             result = self.model_call_details["complete_response"]
@@ -1733,9 +1747,9 @@ class Logging(LiteLLMLoggingBaseClass):
         if complete_streaming_response is not None:
             print_verbose("Async success callbacks: Got a complete streaming response")

-            self.model_call_details[
-                "async_complete_streaming_response"
-            ] = complete_streaming_response
+            self.model_call_details["async_complete_streaming_response"] = (
+                complete_streaming_response
+            )
             try:
                 if self.model_call_details.get("cache_hit", False) is True:
                     self.model_call_details["response_cost"] = 0.0
@@ -1745,10 +1759,10 @@ class Logging(LiteLLMLoggingBaseClass):
                         model_call_details=self.model_call_details
                     )
                     # base_model defaults to None if not set on model_info
-                    self.model_call_details[
-                        "response_cost"
-                    ] = self._response_cost_calculator(
-                        result=complete_streaming_response
-                    )
+                    self.model_call_details["response_cost"] = (
+                        self._response_cost_calculator(
+                            result=complete_streaming_response
+                        )
+                    )

                     verbose_logger.debug(
@@ -1761,16 +1775,16 @@ class Logging(LiteLLMLoggingBaseClass):
                     self.model_call_details["response_cost"] = None

             ## STANDARDIZED LOGGING PAYLOAD
-            self.model_call_details[
-                "standard_logging_object"
-            ] = get_standard_logging_object_payload(
-                kwargs=self.model_call_details,
-                init_response_obj=complete_streaming_response,
-                start_time=start_time,
-                end_time=end_time,
-                logging_obj=self,
-                status="success",
-                standard_built_in_tools_params=self.standard_built_in_tools_params,
-            )
+            self.model_call_details["standard_logging_object"] = (
+                get_standard_logging_object_payload(
+                    kwargs=self.model_call_details,
+                    init_response_obj=complete_streaming_response,
+                    start_time=start_time,
+                    end_time=end_time,
+                    logging_obj=self,
+                    status="success",
+                    standard_built_in_tools_params=self.standard_built_in_tools_params,
+                )
+            )
         callbacks = self.get_combined_callback_list(
             dynamic_success_callbacks=self.dynamic_async_success_callbacks,
@@ -1976,18 +1990,18 @@ class Logging(LiteLLMLoggingBaseClass):

         ## STANDARDIZED LOGGING PAYLOAD
-        self.model_call_details[
-            "standard_logging_object"
-        ] = get_standard_logging_object_payload(
-            kwargs=self.model_call_details,
-            init_response_obj={},
-            start_time=start_time,
-            end_time=end_time,
-            logging_obj=self,
-            status="failure",
-            error_str=str(exception),
-            original_exception=exception,
-            standard_built_in_tools_params=self.standard_built_in_tools_params,
-        )
+        self.model_call_details["standard_logging_object"] = (
+            get_standard_logging_object_payload(
+                kwargs=self.model_call_details,
+                init_response_obj={},
+                start_time=start_time,
+                end_time=end_time,
+                logging_obj=self,
+                status="failure",
+                error_str=str(exception),
+                original_exception=exception,
+                standard_built_in_tools_params=self.standard_built_in_tools_params,
+            )
+        )
         return start_time, end_time
@@ -2753,9 +2767,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
             endpoint=arize_config.endpoint,
         )
-        os.environ[
-            "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-        ] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
+        os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+            f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
+        )
         for callback in _in_memory_loggers:
             if (
                 isinstance(callback, ArizeLogger)
@@ -2779,9 +2793,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         # auth can be disabled on local deployments of arize phoenix
         if arize_phoenix_config.otlp_auth_headers is not None:
-            os.environ[
-                "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-            ] = arize_phoenix_config.otlp_auth_headers
+            os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+                arize_phoenix_config.otlp_auth_headers
+            )

         for callback in _in_memory_loggers:
@@ -2872,9 +2886,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
             exporter="otlp_http",
             endpoint="https://langtrace.ai/api/trace",
         )
-        os.environ[
-            "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-        ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
+        os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+            f"api_key={os.getenv('LANGTRACE_API_KEY')}"
+        )
         for callback in _in_memory_loggers:
             if (
                 isinstance(callback, OpenTelemetry)
@@ -3369,10 +3383,10 @@ class StandardLoggingPayloadSetup:
         for key in StandardLoggingHiddenParams.__annotations__.keys():
             if key in hidden_params:
                 if key == "additional_headers":
-                    clean_hidden_params[
-                        "additional_headers"
-                    ] = StandardLoggingPayloadSetup.get_additional_headers(
-                        hidden_params[key]
-                    )
+                    clean_hidden_params["additional_headers"] = (
+                        StandardLoggingPayloadSetup.get_additional_headers(
+                            hidden_params[key]
+                        )
+                    )
                 else:
                     clean_hidden_params[key] = hidden_params[key]  # type: ignore
@@ -3651,7 +3665,7 @@ def emit_standard_logging_payload(payload: StandardLoggingPayload):


 def get_standard_logging_metadata(
-    metadata: Optional[Dict[str, Any]]
+    metadata: Optional[Dict[str, Any]],
 ) -> StandardLoggingMetadata:
     """
     Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -3715,9 +3729,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
         ):
             for k, v in metadata["user_api_key_metadata"].items():
                 if k == "logging":  # prevent logging user logging keys
-                    cleaned_user_api_key_metadata[
-                        k
-                    ] = "scrubbed_by_litellm_for_sensitive_keys"
+                    cleaned_user_api_key_metadata[k] = (
+                        "scrubbed_by_litellm_for_sensitive_keys"
+                    )
                 else:
                     cleaned_user_api_key_metadata[k] = v


@@ -954,7 +954,11 @@ def completion(  # type: ignore  # noqa: PLR0915
     non_default_params = get_non_default_completion_params(kwargs=kwargs)
     litellm_params = {}  # used to prevent unbound var errors
     ## PROMPT MANAGEMENT HOOKS ##
-    if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
+    if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and (
+        litellm_logging_obj.should_run_prompt_management_hooks(
+            prompt_id=prompt_id, kwargs=kwargs
+        )
+    ):
         (
             model,
             messages,
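With this change the prompt-management branch in `completion()` is reachable in two ways instead of one: as before, when a `prompt_id` is supplied, and now also when the cache-control breakpoint kwarg is present. A hedged usage sketch follows; the calls are commented out because they need real provider credentials, and the breakpoint payload shape is a placeholder inferred from the kwargs key, not something this diff specifies:

```python
import litellm

# 1) Same as before this commit: a prompt_id routes the request through the
#    prompt management hooks.
# response = litellm.completion(
#     model="gpt-4o-mini",
#     messages=[{"role": "user", "content": "hi"}],
#     prompt_id="my-prompt-id",
# )

# 2) New with this commit: if the breakpoint-injection kwarg is forwarded via
#    **kwargs (placeholder value below), the hooks now run even without a prompt_id.
# response = litellm.completion(
#     model="gpt-4o-mini",
#     messages=[{"role": "user", "content": "hi"}],
#     inject_cache_control_breakpoints_locations=[{"role": "system"}],
# )
```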
@@ -2654,9 +2658,9 @@ def completion(  # type: ignore  # noqa: PLR0915
                 "aws_region_name" not in optional_params
                 or optional_params["aws_region_name"] is None
             ):
-                optional_params[
-                    "aws_region_name"
-                ] = aws_bedrock_client.meta.region_name
+                optional_params["aws_region_name"] = (
+                    aws_bedrock_client.meta.region_name
+                )

             bedrock_route = BedrockModelInfo.get_bedrock_route(model)
             if bedrock_route == "converse":
@@ -4363,9 +4367,9 @@ def adapter_completion(
     new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
     response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs)  # type: ignore
-    translated_response: Optional[
-        Union[BaseModel, AdapterCompletionStreamWrapper]
-    ] = None
+    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
+        None
+    )

     if isinstance(response, ModelResponse):
         translated_response = translation_obj.translate_completion_output_params(
             response=response
@@ -5785,9 +5789,9 @@ def stream_chunk_builder(  # noqa: PLR0915
     ]
     if len(content_chunks) > 0:
-        response["choices"][0]["message"][
-            "content"
-        ] = processor.get_combined_content(content_chunks)
+        response["choices"][0]["message"]["content"] = (
+            processor.get_combined_content(content_chunks)
+        )

     reasoning_chunks = [
         chunk
@@ -5798,9 +5802,9 @@ def stream_chunk_builder(  # noqa: PLR0915
     ]
     if len(reasoning_chunks) > 0:
-        response["choices"][0]["message"][
-            "reasoning_content"
-        ] = processor.get_combined_reasoning_content(reasoning_chunks)
+        response["choices"][0]["message"]["reasoning_content"] = (
+            processor.get_combined_reasoning_content(reasoning_chunks)
+        )

     audio_chunks = [
         chunk