mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
should_run_prompt_management_hooks
This commit is contained in:
parent
d986b5d6b1
commit
e64254b381
2 changed files with 167 additions and 149 deletions
|
@ -249,9 +249,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
self.litellm_trace_id = litellm_trace_id
|
||||
self.function_id = function_id
|
||||
self.streaming_chunks: List[Any] = [] # for generating complete stream response
|
||||
self.sync_streaming_chunks: List[
|
||||
Any
|
||||
] = [] # for generating complete stream response
|
||||
self.sync_streaming_chunks: List[Any] = (
|
||||
[]
|
||||
) # for generating complete stream response
|
||||
self.log_raw_request_response = log_raw_request_response
|
||||
|
||||
# Initialize dynamic callbacks
|
||||
|
@ -455,6 +455,20 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
if "custom_llm_provider" in self.model_call_details:
|
||||
self.custom_llm_provider = self.model_call_details["custom_llm_provider"]
|
||||
|
||||
def should_run_prompt_management_hooks(
|
||||
self,
|
||||
prompt_id: str,
|
||||
kwargs: Dict,
|
||||
) -> bool:
|
||||
"""
|
||||
Return True if prompt management hooks should be run
|
||||
"""
|
||||
if prompt_id:
|
||||
return True
|
||||
if kwargs.get("inject_cache_control_breakpoints_locations", None):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -557,9 +571,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
model
|
||||
): # if model name was changes pre-call, overwrite the initial model call name with the new one
|
||||
self.model_call_details["model"] = model
|
||||
self.model_call_details["litellm_params"][
|
||||
"api_base"
|
||||
] = self._get_masked_api_base(additional_args.get("api_base", ""))
|
||||
self.model_call_details["litellm_params"]["api_base"] = (
|
||||
self._get_masked_api_base(additional_args.get("api_base", ""))
|
||||
)
|
||||
|
||||
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
|
||||
# Log the exact input to the LLM API
|
||||
|
@ -588,10 +602,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
try:
|
||||
# [Non-blocking Extra Debug Information in metadata]
|
||||
if turn_off_message_logging is True:
|
||||
_metadata[
|
||||
"raw_request"
|
||||
] = "redacted by litellm. \
|
||||
_metadata["raw_request"] = (
|
||||
"redacted by litellm. \
|
||||
'litellm.turn_off_message_logging=True'"
|
||||
)
|
||||
else:
|
||||
curl_command = self._get_request_curl_command(
|
||||
api_base=additional_args.get("api_base", ""),
|
||||
|
@ -602,32 +616,32 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
|
||||
_metadata["raw_request"] = str(curl_command)
|
||||
# split up, so it's easier to parse in the UI
|
||||
self.model_call_details[
|
||||
"raw_request_typed_dict"
|
||||
] = RawRequestTypedDict(
|
||||
raw_request_api_base=str(
|
||||
additional_args.get("api_base") or ""
|
||||
),
|
||||
raw_request_body=self._get_raw_request_body(
|
||||
additional_args.get("complete_input_dict", {})
|
||||
),
|
||||
raw_request_headers=self._get_masked_headers(
|
||||
additional_args.get("headers", {}) or {},
|
||||
ignore_sensitive_headers=True,
|
||||
),
|
||||
error=None,
|
||||
self.model_call_details["raw_request_typed_dict"] = (
|
||||
RawRequestTypedDict(
|
||||
raw_request_api_base=str(
|
||||
additional_args.get("api_base") or ""
|
||||
),
|
||||
raw_request_body=self._get_raw_request_body(
|
||||
additional_args.get("complete_input_dict", {})
|
||||
),
|
||||
raw_request_headers=self._get_masked_headers(
|
||||
additional_args.get("headers", {}) or {},
|
||||
ignore_sensitive_headers=True,
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
self.model_call_details[
|
||||
"raw_request_typed_dict"
|
||||
] = RawRequestTypedDict(
|
||||
error=str(e),
|
||||
self.model_call_details["raw_request_typed_dict"] = (
|
||||
RawRequestTypedDict(
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
_metadata[
|
||||
"raw_request"
|
||||
] = "Unable to Log \
|
||||
_metadata["raw_request"] = (
|
||||
"Unable to Log \
|
||||
raw request: {}".format(
|
||||
str(e)
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
if self.logger_fn and callable(self.logger_fn):
|
||||
try:
|
||||
|
@ -957,9 +971,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
verbose_logger.debug(
|
||||
f"response_cost_failure_debug_information: {debug_info}"
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost_failure_debug_information"
|
||||
] = debug_info
|
||||
self.model_call_details["response_cost_failure_debug_information"] = (
|
||||
debug_info
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
|
@ -984,9 +998,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
verbose_logger.debug(
|
||||
f"response_cost_failure_debug_information: {debug_info}"
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost_failure_debug_information"
|
||||
] = debug_info
|
||||
self.model_call_details["response_cost_failure_debug_information"] = (
|
||||
debug_info
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
@ -1046,9 +1060,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
end_time = datetime.datetime.now()
|
||||
if self.completion_start_time is None:
|
||||
self.completion_start_time = end_time
|
||||
self.model_call_details[
|
||||
"completion_start_time"
|
||||
] = self.completion_start_time
|
||||
self.model_call_details["completion_start_time"] = (
|
||||
self.completion_start_time
|
||||
)
|
||||
self.model_call_details["log_event_type"] = "successful_api_call"
|
||||
self.model_call_details["end_time"] = end_time
|
||||
self.model_call_details["cache_hit"] = cache_hit
|
||||
|
@ -1127,39 +1141,39 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
"response_cost"
|
||||
]
|
||||
else:
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(result=logging_result)
|
||||
self.model_call_details["response_cost"] = (
|
||||
self._response_cost_calculator(result=logging_result)
|
||||
)
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=logging_result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=logging_result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
)
|
||||
elif isinstance(result, dict) or isinstance(result, list):
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
)
|
||||
elif standard_logging_object is not None:
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = standard_logging_object
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
standard_logging_object
|
||||
)
|
||||
else: # streaming chunks + image gen.
|
||||
self.model_call_details["response_cost"] = None
|
||||
|
||||
|
@ -1215,23 +1229,23 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
verbose_logger.debug(
|
||||
"Logging Details LiteLLM-Success Call streaming complete"
|
||||
)
|
||||
self.model_call_details[
|
||||
"complete_streaming_response"
|
||||
] = complete_streaming_response
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(result=complete_streaming_response)
|
||||
self.model_call_details["complete_streaming_response"] = (
|
||||
complete_streaming_response
|
||||
)
|
||||
self.model_call_details["response_cost"] = (
|
||||
self._response_cost_calculator(result=complete_streaming_response)
|
||||
)
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=complete_streaming_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=complete_streaming_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
)
|
||||
callbacks = self.get_combined_callback_list(
|
||||
dynamic_success_callbacks=self.dynamic_success_callbacks,
|
||||
|
@ -1580,10 +1594,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
)
|
||||
else:
|
||||
if self.stream and complete_streaming_response:
|
||||
self.model_call_details[
|
||||
"complete_response"
|
||||
] = self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
self.model_call_details["complete_response"] = (
|
||||
self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
)
|
||||
result = self.model_call_details["complete_response"]
|
||||
openMeterLogger.log_success_event(
|
||||
|
@ -1623,10 +1637,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
)
|
||||
else:
|
||||
if self.stream and complete_streaming_response:
|
||||
self.model_call_details[
|
||||
"complete_response"
|
||||
] = self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
self.model_call_details["complete_response"] = (
|
||||
self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
)
|
||||
result = self.model_call_details["complete_response"]
|
||||
|
||||
|
@ -1733,9 +1747,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
if complete_streaming_response is not None:
|
||||
print_verbose("Async success callbacks: Got a complete streaming response")
|
||||
|
||||
self.model_call_details[
|
||||
"async_complete_streaming_response"
|
||||
] = complete_streaming_response
|
||||
self.model_call_details["async_complete_streaming_response"] = (
|
||||
complete_streaming_response
|
||||
)
|
||||
try:
|
||||
if self.model_call_details.get("cache_hit", False) is True:
|
||||
self.model_call_details["response_cost"] = 0.0
|
||||
|
@ -1745,10 +1759,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
model_call_details=self.model_call_details
|
||||
)
|
||||
# base_model defaults to None if not set on model_info
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(
|
||||
result=complete_streaming_response
|
||||
self.model_call_details["response_cost"] = (
|
||||
self._response_cost_calculator(
|
||||
result=complete_streaming_response
|
||||
)
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
|
@ -1761,16 +1775,16 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
self.model_call_details["response_cost"] = None
|
||||
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=complete_streaming_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=complete_streaming_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
)
|
||||
callbacks = self.get_combined_callback_list(
|
||||
dynamic_success_callbacks=self.dynamic_async_success_callbacks,
|
||||
|
@ -1976,18 +1990,18 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj={},
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="failure",
|
||||
error_str=str(exception),
|
||||
original_exception=exception,
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj={},
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="failure",
|
||||
error_str=str(exception),
|
||||
original_exception=exception,
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
)
|
||||
return start_time, end_time
|
||||
|
||||
|
@ -2753,9 +2767,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
|||
endpoint=arize_config.endpoint,
|
||||
)
|
||||
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
|
||||
)
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
isinstance(callback, ArizeLogger)
|
||||
|
@ -2779,9 +2793,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
|||
|
||||
# auth can be disabled on local deployments of arize phoenix
|
||||
if arize_phoenix_config.otlp_auth_headers is not None:
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = arize_phoenix_config.otlp_auth_headers
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
arize_phoenix_config.otlp_auth_headers
|
||||
)
|
||||
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
|
@ -2872,9 +2886,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
|||
exporter="otlp_http",
|
||||
endpoint="https://langtrace.ai/api/trace",
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
|
||||
)
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
isinstance(callback, OpenTelemetry)
|
||||
|
@ -3369,10 +3383,10 @@ class StandardLoggingPayloadSetup:
|
|||
for key in StandardLoggingHiddenParams.__annotations__.keys():
|
||||
if key in hidden_params:
|
||||
if key == "additional_headers":
|
||||
clean_hidden_params[
|
||||
"additional_headers"
|
||||
] = StandardLoggingPayloadSetup.get_additional_headers(
|
||||
hidden_params[key]
|
||||
clean_hidden_params["additional_headers"] = (
|
||||
StandardLoggingPayloadSetup.get_additional_headers(
|
||||
hidden_params[key]
|
||||
)
|
||||
)
|
||||
else:
|
||||
clean_hidden_params[key] = hidden_params[key] # type: ignore
|
||||
|
@ -3651,7 +3665,7 @@ def emit_standard_logging_payload(payload: StandardLoggingPayload):
|
|||
|
||||
|
||||
def get_standard_logging_metadata(
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
metadata: Optional[Dict[str, Any]],
|
||||
) -> StandardLoggingMetadata:
|
||||
"""
|
||||
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
|
||||
|
@ -3715,9 +3729,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
|
|||
):
|
||||
for k, v in metadata["user_api_key_metadata"].items():
|
||||
if k == "logging": # prevent logging user logging keys
|
||||
cleaned_user_api_key_metadata[
|
||||
k
|
||||
] = "scrubbed_by_litellm_for_sensitive_keys"
|
||||
cleaned_user_api_key_metadata[k] = (
|
||||
"scrubbed_by_litellm_for_sensitive_keys"
|
||||
)
|
||||
else:
|
||||
cleaned_user_api_key_metadata[k] = v
|
||||
|
||||
|
|
|
@ -954,7 +954,11 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
non_default_params = get_non_default_completion_params(kwargs=kwargs)
|
||||
litellm_params = {} # used to prevent unbound var errors
|
||||
## PROMPT MANAGEMENT HOOKS ##
|
||||
if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
|
||||
if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and (
|
||||
litellm_logging_obj.should_run_prompt_management_hooks(
|
||||
prompt_id=prompt_id, kwargs=kwargs
|
||||
)
|
||||
):
|
||||
(
|
||||
model,
|
||||
messages,
|
||||
|
@ -2654,9 +2658,9 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
"aws_region_name" not in optional_params
|
||||
or optional_params["aws_region_name"] is None
|
||||
):
|
||||
optional_params[
|
||||
"aws_region_name"
|
||||
] = aws_bedrock_client.meta.region_name
|
||||
optional_params["aws_region_name"] = (
|
||||
aws_bedrock_client.meta.region_name
|
||||
)
|
||||
|
||||
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
|
||||
if bedrock_route == "converse":
|
||||
|
@ -4363,9 +4367,9 @@ def adapter_completion(
|
|||
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
|
||||
|
||||
response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore
|
||||
translated_response: Optional[
|
||||
Union[BaseModel, AdapterCompletionStreamWrapper]
|
||||
] = None
|
||||
translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
|
||||
None
|
||||
)
|
||||
if isinstance(response, ModelResponse):
|
||||
translated_response = translation_obj.translate_completion_output_params(
|
||||
response=response
|
||||
|
@ -5785,9 +5789,9 @@ def stream_chunk_builder( # noqa: PLR0915
|
|||
]
|
||||
|
||||
if len(content_chunks) > 0:
|
||||
response["choices"][0]["message"][
|
||||
"content"
|
||||
] = processor.get_combined_content(content_chunks)
|
||||
response["choices"][0]["message"]["content"] = (
|
||||
processor.get_combined_content(content_chunks)
|
||||
)
|
||||
|
||||
reasoning_chunks = [
|
||||
chunk
|
||||
|
@ -5798,9 +5802,9 @@ def stream_chunk_builder( # noqa: PLR0915
|
|||
]
|
||||
|
||||
if len(reasoning_chunks) > 0:
|
||||
response["choices"][0]["message"][
|
||||
"reasoning_content"
|
||||
] = processor.get_combined_reasoning_content(reasoning_chunks)
|
||||
response["choices"][0]["message"]["reasoning_content"] = (
|
||||
processor.get_combined_reasoning_content(reasoning_chunks)
|
||||
)
|
||||
|
||||
audio_chunks = [
|
||||
chunk
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue