mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

Commit e64254b381 (parent d986b5d6b1): should_run_prompt_management_hooks

2 changed files with 167 additions and 149 deletions
@@ -249,9 +249,9 @@ class Logging(LiteLLMLoggingBaseClass):
         self.litellm_trace_id = litellm_trace_id
         self.function_id = function_id
         self.streaming_chunks: List[Any] = []  # for generating complete stream response
-        self.sync_streaming_chunks: List[
-            Any
-        ] = []  # for generating complete stream response
+        self.sync_streaming_chunks: List[Any] = (
+            []
+        )  # for generating complete stream response
         self.log_raw_request_response = log_raw_request_response

         # Initialize dynamic callbacks
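Most hunks in this commit are mechanical reformats like the one above: an assignment whose subscript used to be split across lines now keeps the key inline and wraps the right-hand side in parentheses instead. A minimal, self-contained before/after sketch of the pattern, using hypothetical names (details, compute_value) rather than code from the diff:

details = {}

def compute_value():
    return 42

# before: the line-length limit is satisfied by splitting the subscript
details[
    "some_key"
] = compute_value()

# after: the key stays inline and the value is parenthesized
details["some_key"] = (
    compute_value()
)

Both forms assign the same value; only the formatting differs.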
@@ -455,6 +455,20 @@ class Logging(LiteLLMLoggingBaseClass):
         if "custom_llm_provider" in self.model_call_details:
             self.custom_llm_provider = self.model_call_details["custom_llm_provider"]

+    def should_run_prompt_management_hooks(
+        self,
+        prompt_id: str,
+        kwargs: Dict,
+    ) -> bool:
+        """
+        Return True if prompt management hooks should be run
+        """
+        if prompt_id:
+            return True
+        if kwargs.get("inject_cache_control_breakpoints_locations", None):
+            return True
+        return False
+
     def get_chat_completion_prompt(
         self,
         model: str,
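The new helper replaces the `prompt_id is not None` check that completion() used before (the call-site change appears further down, in the completion() hunks): hooks now run when a prompt_id is supplied, or when the caller asks for cache-control breakpoint injection. A standalone, runnable sketch of the same decision logic, written outside the Logging class purely to show its truth table:

from typing import Dict

def should_run_prompt_management_hooks(prompt_id: str, kwargs: Dict) -> bool:
    # mirrors the method added above: hooks run for a non-empty prompt_id
    # or when cache-control breakpoint injection is requested via kwargs
    if prompt_id:
        return True
    if kwargs.get("inject_cache_control_breakpoints_locations", None):
        return True
    return False

print(should_run_prompt_management_hooks("my-prompt-id", {}))  # True
print(should_run_prompt_management_hooks("", {"inject_cache_control_breakpoints_locations": ["system"]}))  # True
print(should_run_prompt_management_hooks("", {}))  # False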
@@ -557,9 +571,9 @@ class Logging(LiteLLMLoggingBaseClass):
             model
         ):  # if model name was changes pre-call, overwrite the initial model call name with the new one
             self.model_call_details["model"] = model
-            self.model_call_details["litellm_params"][
-                "api_base"
-            ] = self._get_masked_api_base(additional_args.get("api_base", ""))
+            self.model_call_details["litellm_params"]["api_base"] = (
+                self._get_masked_api_base(additional_args.get("api_base", ""))
+            )

     def pre_call(self, input, api_key, model=None, additional_args={}):  # noqa: PLR0915
         # Log the exact input to the LLM API
@@ -588,10 +602,10 @@ class Logging(LiteLLMLoggingBaseClass):
         try:
             # [Non-blocking Extra Debug Information in metadata]
             if turn_off_message_logging is True:
-                _metadata[
-                    "raw_request"
-                ] = "redacted by litellm. \
+                _metadata["raw_request"] = (
+                    "redacted by litellm. \
 'litellm.turn_off_message_logging=True'"
+                )
             else:
                 curl_command = self._get_request_curl_command(
                     api_base=additional_args.get("api_base", ""),
@@ -602,9 +616,8 @@ class Logging(LiteLLMLoggingBaseClass):

                 _metadata["raw_request"] = str(curl_command)
                 # split up, so it's easier to parse in the UI
-                self.model_call_details[
-                    "raw_request_typed_dict"
-                ] = RawRequestTypedDict(
+                self.model_call_details["raw_request_typed_dict"] = (
+                    RawRequestTypedDict(
                         raw_request_api_base=str(
                             additional_args.get("api_base") or ""
                         ),
@@ -617,18 +630,19 @@ class Logging(LiteLLMLoggingBaseClass):
                         ),
                         error=None,
                     )
+                )
             except Exception as e:
-                self.model_call_details[
-                    "raw_request_typed_dict"
-                ] = RawRequestTypedDict(
+                self.model_call_details["raw_request_typed_dict"] = (
+                    RawRequestTypedDict(
                         error=str(e),
                     )
-                _metadata[
-                    "raw_request"
-                ] = "Unable to Log \
+                )
+                _metadata["raw_request"] = (
+                    "Unable to Log \
 raw request: {}".format(
                     str(e)
                 )
+                )
             if self.logger_fn and callable(self.logger_fn):
                 try:
                     self.logger_fn(
@@ -957,9 +971,9 @@ class Logging(LiteLLMLoggingBaseClass):
             verbose_logger.debug(
                 f"response_cost_failure_debug_information: {debug_info}"
             )
-            self.model_call_details[
-                "response_cost_failure_debug_information"
-            ] = debug_info
+            self.model_call_details["response_cost_failure_debug_information"] = (
+                debug_info
+            )
             return None

         try:
@@ -984,9 +998,9 @@ class Logging(LiteLLMLoggingBaseClass):
             verbose_logger.debug(
                 f"response_cost_failure_debug_information: {debug_info}"
             )
-            self.model_call_details[
-                "response_cost_failure_debug_information"
-            ] = debug_info
+            self.model_call_details["response_cost_failure_debug_information"] = (
+                debug_info
+            )

         return None

@@ -1046,9 +1060,9 @@ class Logging(LiteLLMLoggingBaseClass):
             end_time = datetime.datetime.now()
         if self.completion_start_time is None:
             self.completion_start_time = end_time
-            self.model_call_details[
-                "completion_start_time"
-            ] = self.completion_start_time
+            self.model_call_details["completion_start_time"] = (
+                self.completion_start_time
+            )
         self.model_call_details["log_event_type"] = "successful_api_call"
         self.model_call_details["end_time"] = end_time
         self.model_call_details["cache_hit"] = cache_hit
@@ -1127,14 +1141,13 @@ class Logging(LiteLLMLoggingBaseClass):
                     "response_cost"
                 ]
            else:
-                self.model_call_details[
-                    "response_cost"
-                ] = self._response_cost_calculator(result=logging_result)
+                self.model_call_details["response_cost"] = (
+                    self._response_cost_calculator(result=logging_result)
+                )
            ## STANDARDIZED LOGGING PAYLOAD

-            self.model_call_details[
-                "standard_logging_object"
-            ] = get_standard_logging_object_payload(
+            self.model_call_details["standard_logging_object"] = (
+                get_standard_logging_object_payload(
                     kwargs=self.model_call_details,
                     init_response_obj=logging_result,
                     start_time=start_time,
@@ -1143,11 +1156,11 @@ class Logging(LiteLLMLoggingBaseClass):
                     status="success",
                     standard_built_in_tools_params=self.standard_built_in_tools_params,
                 )
+            )
         elif isinstance(result, dict) or isinstance(result, list):
             ## STANDARDIZED LOGGING PAYLOAD
-            self.model_call_details[
-                "standard_logging_object"
-            ] = get_standard_logging_object_payload(
+            self.model_call_details["standard_logging_object"] = (
+                get_standard_logging_object_payload(
                     kwargs=self.model_call_details,
                     init_response_obj=result,
                     start_time=start_time,
@@ -1156,10 +1169,11 @@ class Logging(LiteLLMLoggingBaseClass):
                     status="success",
                     standard_built_in_tools_params=self.standard_built_in_tools_params,
                 )
+            )
         elif standard_logging_object is not None:
-            self.model_call_details[
-                "standard_logging_object"
-            ] = standard_logging_object
+            self.model_call_details["standard_logging_object"] = (
+                standard_logging_object
+            )
         else:  # streaming chunks + image gen.
             self.model_call_details["response_cost"] = None

@@ -1215,16 +1229,15 @@ class Logging(LiteLLMLoggingBaseClass):
                 verbose_logger.debug(
                     "Logging Details LiteLLM-Success Call streaming complete"
                 )
-                self.model_call_details[
-                    "complete_streaming_response"
-                ] = complete_streaming_response
-                self.model_call_details[
-                    "response_cost"
-                ] = self._response_cost_calculator(result=complete_streaming_response)
+                self.model_call_details["complete_streaming_response"] = (
+                    complete_streaming_response
+                )
+                self.model_call_details["response_cost"] = (
+                    self._response_cost_calculator(result=complete_streaming_response)
+                )
                 ## STANDARDIZED LOGGING PAYLOAD
-                self.model_call_details[
-                    "standard_logging_object"
-                ] = get_standard_logging_object_payload(
+                self.model_call_details["standard_logging_object"] = (
+                    get_standard_logging_object_payload(
                         kwargs=self.model_call_details,
                         init_response_obj=complete_streaming_response,
                         start_time=start_time,
@@ -1233,6 +1246,7 @@ class Logging(LiteLLMLoggingBaseClass):
                         status="success",
                         standard_built_in_tools_params=self.standard_built_in_tools_params,
                     )
+                )
                 callbacks = self.get_combined_callback_list(
                     dynamic_success_callbacks=self.dynamic_success_callbacks,
                     global_callbacks=litellm.success_callback,
@@ -1580,11 +1594,11 @@ class Logging(LiteLLMLoggingBaseClass):
                         )
                     else:
                         if self.stream and complete_streaming_response:
-                            self.model_call_details[
-                                "complete_response"
-                            ] = self.model_call_details.get(
+                            self.model_call_details["complete_response"] = (
+                                self.model_call_details.get(
                                     "complete_streaming_response", {}
                                 )
+                            )
                             result = self.model_call_details["complete_response"]
                         openMeterLogger.log_success_event(
                             kwargs=self.model_call_details,
@@ -1623,11 +1637,11 @@ class Logging(LiteLLMLoggingBaseClass):
                         )
                     else:
                         if self.stream and complete_streaming_response:
-                            self.model_call_details[
-                                "complete_response"
-                            ] = self.model_call_details.get(
+                            self.model_call_details["complete_response"] = (
+                                self.model_call_details.get(
                                     "complete_streaming_response", {}
                                 )
+                            )
                             result = self.model_call_details["complete_response"]

                         callback.log_success_event(
@@ -1733,9 +1747,9 @@ class Logging(LiteLLMLoggingBaseClass):
         if complete_streaming_response is not None:
             print_verbose("Async success callbacks: Got a complete streaming response")

-            self.model_call_details[
-                "async_complete_streaming_response"
-            ] = complete_streaming_response
+            self.model_call_details["async_complete_streaming_response"] = (
+                complete_streaming_response
+            )
             try:
                 if self.model_call_details.get("cache_hit", False) is True:
                     self.model_call_details["response_cost"] = 0.0
@@ -1745,11 +1759,11 @@ class Logging(LiteLLMLoggingBaseClass):
                         model_call_details=self.model_call_details
                     )
                     # base_model defaults to None if not set on model_info
-                    self.model_call_details[
-                        "response_cost"
-                    ] = self._response_cost_calculator(
+                    self.model_call_details["response_cost"] = (
+                        self._response_cost_calculator(
                             result=complete_streaming_response
                         )
+                    )

                     verbose_logger.debug(
                         f"Model={self.model}; cost={self.model_call_details['response_cost']}"
@@ -1761,9 +1775,8 @@ class Logging(LiteLLMLoggingBaseClass):
                 self.model_call_details["response_cost"] = None

            ## STANDARDIZED LOGGING PAYLOAD
-            self.model_call_details[
-                "standard_logging_object"
-            ] = get_standard_logging_object_payload(
+            self.model_call_details["standard_logging_object"] = (
+                get_standard_logging_object_payload(
                     kwargs=self.model_call_details,
                     init_response_obj=complete_streaming_response,
                     start_time=start_time,
@@ -1772,6 +1785,7 @@ class Logging(LiteLLMLoggingBaseClass):
                     status="success",
                     standard_built_in_tools_params=self.standard_built_in_tools_params,
                 )
+            )
             callbacks = self.get_combined_callback_list(
                 dynamic_success_callbacks=self.dynamic_async_success_callbacks,
                 global_callbacks=litellm._async_success_callback,
@@ -1976,9 +1990,8 @@ class Logging(LiteLLMLoggingBaseClass):

        ## STANDARDIZED LOGGING PAYLOAD

-        self.model_call_details[
-            "standard_logging_object"
-        ] = get_standard_logging_object_payload(
+        self.model_call_details["standard_logging_object"] = (
+            get_standard_logging_object_payload(
                 kwargs=self.model_call_details,
                 init_response_obj={},
                 start_time=start_time,
@@ -1989,6 +2002,7 @@ class Logging(LiteLLMLoggingBaseClass):
                original_exception=exception,
                standard_built_in_tools_params=self.standard_built_in_tools_params,
            )
+        )
        return start_time, end_time

    async def special_failure_handlers(self, exception: Exception):
@@ -2753,9 +2767,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
            endpoint=arize_config.endpoint,
        )

-        os.environ[
-            "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-        ] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
+        os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+            f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
+        )
        for callback in _in_memory_loggers:
            if (
                isinstance(callback, ArizeLogger)
@@ -2779,9 +2793,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915

        # auth can be disabled on local deployments of arize phoenix
        if arize_phoenix_config.otlp_auth_headers is not None:
-            os.environ[
-                "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-            ] = arize_phoenix_config.otlp_auth_headers
+            os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+                arize_phoenix_config.otlp_auth_headers
+            )

        for callback in _in_memory_loggers:
            if (
@@ -2872,9 +2886,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
            exporter="otlp_http",
            endpoint="https://langtrace.ai/api/trace",
        )
-        os.environ[
-            "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-        ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
+        os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+            f"api_key={os.getenv('LANGTRACE_API_KEY')}"
+        )
        for callback in _in_memory_loggers:
            if (
                isinstance(callback, OpenTelemetry)
@@ -3369,11 +3383,11 @@ class StandardLoggingPayloadSetup:
        for key in StandardLoggingHiddenParams.__annotations__.keys():
            if key in hidden_params:
                if key == "additional_headers":
-                    clean_hidden_params[
-                        "additional_headers"
-                    ] = StandardLoggingPayloadSetup.get_additional_headers(
+                    clean_hidden_params["additional_headers"] = (
+                        StandardLoggingPayloadSetup.get_additional_headers(
                            hidden_params[key]
                        )
+                    )
                else:
                    clean_hidden_params[key] = hidden_params[key]  # type: ignore
        return clean_hidden_params
@@ -3651,7 +3665,7 @@ def emit_standard_logging_payload(payload: StandardLoggingPayload):


 def get_standard_logging_metadata(
-    metadata: Optional[Dict[str, Any]]
+    metadata: Optional[Dict[str, Any]],
 ) -> StandardLoggingMetadata:
     """
     Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -3715,9 +3729,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
        ):
            for k, v in metadata["user_api_key_metadata"].items():
                if k == "logging":  # prevent logging user logging keys
-                    cleaned_user_api_key_metadata[
-                        k
-                    ] = "scrubbed_by_litellm_for_sensitive_keys"
+                    cleaned_user_api_key_metadata[k] = (
+                        "scrubbed_by_litellm_for_sensitive_keys"
+                    )
                else:
                    cleaned_user_api_key_metadata[k] = v

@@ -954,7 +954,11 @@ def completion(  # type: ignore  # noqa: PLR0915
     non_default_params = get_non_default_completion_params(kwargs=kwargs)
     litellm_params = {}  # used to prevent unbound var errors
     ## PROMPT MANAGEMENT HOOKS ##
-    if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
+    if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and (
+        litellm_logging_obj.should_run_prompt_management_hooks(
+            prompt_id=prompt_id, kwargs=kwargs
+        )
+    ):
         (
             model,
             messages,
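The practical effect of this call-site change: a completion() request that passes no prompt_id but does ask for cache-control breakpoint injection now enters the prompt-management block, where previously it was skipped. A self-contained sketch of the old versus new gating with stand-in values (the isinstance check on the logging object is omitted here):

prompt_id = None
kwargs = {"inject_cache_control_breakpoints_locations": ["system"]}

old_gate = prompt_id is not None  # False: hooks were skipped
new_gate = bool(prompt_id) or bool(
    kwargs.get("inject_cache_control_breakpoints_locations", None)
)  # True: hooks now run
print(old_gate, new_gate)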
@@ -2654,9 +2658,9 @@ def completion(
                "aws_region_name" not in optional_params
                or optional_params["aws_region_name"] is None
            ):
-                optional_params[
-                    "aws_region_name"
-                ] = aws_bedrock_client.meta.region_name
+                optional_params["aws_region_name"] = (
+                    aws_bedrock_client.meta.region_name
+                )

            bedrock_route = BedrockModelInfo.get_bedrock_route(model)
            if bedrock_route == "converse":
@@ -4363,9 +4367,9 @@ def adapter_completion(
     new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

     response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs)  # type: ignore
-    translated_response: Optional[
-        Union[BaseModel, AdapterCompletionStreamWrapper]
-    ] = None
+    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
+        None
+    )
     if isinstance(response, ModelResponse):
         translated_response = translation_obj.translate_completion_output_params(
             response=response
@@ -5785,9 +5789,9 @@ def stream_chunk_builder(  # noqa: PLR0915
     ]

     if len(content_chunks) > 0:
-        response["choices"][0]["message"][
-            "content"
-        ] = processor.get_combined_content(content_chunks)
+        response["choices"][0]["message"]["content"] = (
+            processor.get_combined_content(content_chunks)
+        )

     reasoning_chunks = [
         chunk
@@ -5798,9 +5802,9 @@ def stream_chunk_builder(
     ]

     if len(reasoning_chunks) > 0:
-        response["choices"][0]["message"][
-            "reasoning_content"
-        ] = processor.get_combined_reasoning_content(reasoning_chunks)
+        response["choices"][0]["message"]["reasoning_content"] = (
+            processor.get_combined_reasoning_content(reasoning_chunks)
+        )

     audio_chunks = [
         chunk
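These last two hunks only reformat the assignments, but for context: the processor's get_combined_content and get_combined_reasoning_content calls are what stitch the streamed deltas back together when stream_chunk_builder rebuilds a full response. A simplified, runnable sketch of what that combination amounts to, using plain dicts as stand-ins for litellm's actual chunk objects:

chunks = [
    {"choices": [{"delta": {"content": "Hel"}}]},
    {"choices": [{"delta": {"content": "lo"}}]},
    {"choices": [{"delta": {"content": " world"}}]},
]

# concatenate the content deltas in arrival order, skipping empty ones
combined = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks)
print(combined)  # Hello world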