Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

[Feat] Add support for cache_control_injection_points for Anthropic API, Bedrock API (#9996)

* test_anthropic_cache_control_hook_system_message
* test_anthropic_cache_control_hook.py
* should_run_prompt_management_hooks
* fix should_run_prompt_management_hooks
* test_anthropic_cache_control_hook_specific_index
* fix test
* fix linting errors
* ChatCompletionCachedContent

This commit is contained in:
parent d5004e3f24
commit 990fda294b

5 changed files with 453 additions and 149 deletions
litellm/integrations/anthropic_cache_control_hook.py (new file, 118 lines)

@@ -0,0 +1,118 @@
"""
This hook is used to inject cache control directives into the messages of a chat completion.

Users can define
- `cache_control_injection_points` in the completion params, and litellm will inject the cache control directives into the messages at the specified injection points.
"""

import copy
from typing import Any, Dict, List, Optional, Tuple, cast

from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.integrations.anthropic_cache_control_hook import (
    CacheControlInjectionPoint,
    CacheControlMessageInjectionPoint,
)
from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
from litellm.types.utils import StandardCallbackDynamicParams


class AnthropicCacheControlHook(CustomPromptManagement):
    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Apply cache control directives based on specified injection points.

        Returns:
        - model: str - the model to use
        - messages: List[AllMessageValues] - messages with applied cache controls
        - non_default_params: dict - params with any global cache controls
        """
        # Extract cache control injection points
        injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
            "cache_control_injection_points", []
        )
        if not injection_points:
            return model, messages, non_default_params

        # Create a deep copy of messages to avoid modifying the original list
        processed_messages = copy.deepcopy(messages)

        # Process message-level cache controls
        for point in injection_points:
            if point.get("location") == "message":
                point = cast(CacheControlMessageInjectionPoint, point)
                processed_messages = self._process_message_injection(
                    point=point, messages=processed_messages
                )

        return model, processed_messages, non_default_params

    @staticmethod
    def _process_message_injection(
        point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """Process message-level cache control injection."""
        control: ChatCompletionCachedContent = point.get(
            "control", None
        ) or ChatCompletionCachedContent(type="ephemeral")
        targetted_index = point.get("index", None)
        targetted_role = point.get("role", None)

        # Case 1: Target by specific index
        if targetted_index is not None:
            if 0 <= targetted_index < len(messages):
                messages[targetted_index] = (
                    AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                        messages[targetted_index], control
                    )
                )
        # Case 2: Target by role
        elif targetted_role is not None:
            for msg in messages:
                if msg.get("role") == targetted_role:
                    # _safe_insert_cache_control_in_message mutates the message dict
                    # in place, so reassigning the loop variable is safe here.
                    msg = (
                        AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                            message=msg, control=control
                        )
                    )
        return messages

    @staticmethod
    def _safe_insert_cache_control_in_message(
        message: AllMessageValues, control: ChatCompletionCachedContent
    ) -> AllMessageValues:
        """
        Safe way to insert cache control in a message

        OpenAI Message content can be either:
        - string
        - list of objects

        This method handles inserting cache control in both cases.
        """
        message_content = message.get("content", None)

        # 1. if string, insert cache control in the message
        if isinstance(message_content, str):
            message["cache_control"] = control  # type: ignore
        # 2. list of objects
        elif isinstance(message_content, list):
            for content_item in message_content:
                if isinstance(content_item, dict):
                    content_item["cache_control"] = control  # type: ignore
        return message

    @property
    def integration_name(self) -> str:
        """Return the integration name for this hook."""
        return "anthropic-cache-control-hook"
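For reference, a minimal usage sketch of the new hook (it mirrors the tests added in this commit; the Bedrock model name is taken from those tests and is illustrative only):

    import litellm
    from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook

    # Register the hook as a callback, then mark every system message as cacheable.
    # The hook pops "cache_control_injection_points" from the call params and applies
    # {"type": "ephemeral"} cache control to the targeted messages; both string and
    # list-of-parts message content are handled.
    litellm.callbacks = [AnthropicCacheControlHook()]

    response = litellm.completion(
        model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
        messages=[
            {"role": "system", "content": "You are an AI assistant tasked with analyzing legal documents."},
            {"role": "user", "content": "what are the key terms and conditions in this agreement?"},
        ],
        cache_control_injection_points=[
            {"location": "message", "role": "system"},
        ],
    )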
@@ -249,9 +249,9 @@ class Logging(LiteLLMLoggingBaseClass):
         self.litellm_trace_id = litellm_trace_id
         self.function_id = function_id
         self.streaming_chunks: List[Any] = []  # for generating complete stream response
-        self.sync_streaming_chunks: List[
-            Any
-        ] = []  # for generating complete stream response
+        self.sync_streaming_chunks: List[Any] = (
+            []
+        )  # for generating complete stream response
         self.log_raw_request_response = log_raw_request_response

         # Initialize dynamic callbacks

@@ -455,6 +455,20 @@ class Logging(LiteLLMLoggingBaseClass):
         if "custom_llm_provider" in self.model_call_details:
             self.custom_llm_provider = self.model_call_details["custom_llm_provider"]

+    def should_run_prompt_management_hooks(
+        self,
+        prompt_id: str,
+        non_default_params: Dict,
+    ) -> bool:
+        """
+        Return True if prompt management hooks should be run
+        """
+        if prompt_id:
+            return True
+        if non_default_params.get("cache_control_injection_points", None):
+            return True
+        return False
+
     def get_chat_completion_prompt(
         self,
         model: str,
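In effect, prompt management hooks now run when either a prompt_id or cache_control_injection_points is supplied. A rough sketch of the two triggers, assuming logging_obj is an existing Logging instance (hypothetical values):

    # Hypothetical calls illustrating the gate added above.
    logging_obj.should_run_prompt_management_hooks(
        prompt_id="my-prompt-id", non_default_params={}
    )  # True: an explicit prompt_id is set
    logging_obj.should_run_prompt_management_hooks(
        prompt_id="",
        non_default_params={
            "cache_control_injection_points": [{"location": "message", "role": "system"}]
        },
    )  # True: cache control injection was requested
    logging_obj.should_run_prompt_management_hooks(
        prompt_id="", non_default_params={}
    )  # False: nothing to do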
@@ -557,9 +571,9 @@ class Logging(LiteLLMLoggingBaseClass):
             model
         ):  # if model name was changes pre-call, overwrite the initial model call name with the new one
             self.model_call_details["model"] = model
-            self.model_call_details["litellm_params"][
-                "api_base"
-            ] = self._get_masked_api_base(additional_args.get("api_base", ""))
+            self.model_call_details["litellm_params"]["api_base"] = (
+                self._get_masked_api_base(additional_args.get("api_base", ""))
+            )

     def pre_call(self, input, api_key, model=None, additional_args={}):  # noqa: PLR0915
         # Log the exact input to the LLM API

@@ -588,10 +602,10 @@ class Logging(LiteLLMLoggingBaseClass):
             try:
                 # [Non-blocking Extra Debug Information in metadata]
                 if turn_off_message_logging is True:
-                    _metadata[
-                        "raw_request"
-                    ] = "redacted by litellm. \
-'litellm.turn_off_message_logging=True'"
+                    _metadata["raw_request"] = (
+                        "redacted by litellm. \
+'litellm.turn_off_message_logging=True'"
+                    )
                 else:
                     curl_command = self._get_request_curl_command(
                         api_base=additional_args.get("api_base", ""),

@@ -602,32 +616,32 @@ class Logging(LiteLLMLoggingBaseClass):

                 _metadata["raw_request"] = str(curl_command)
                 # split up, so it's easier to parse in the UI
-                self.model_call_details[
-                    "raw_request_typed_dict"
-                ] = RawRequestTypedDict(
-                    raw_request_api_base=str(
-                        additional_args.get("api_base") or ""
-                    ),
-                    raw_request_body=self._get_raw_request_body(
-                        additional_args.get("complete_input_dict", {})
-                    ),
-                    raw_request_headers=self._get_masked_headers(
-                        additional_args.get("headers", {}) or {},
-                        ignore_sensitive_headers=True,
-                    ),
-                    error=None,
-                )
+                self.model_call_details["raw_request_typed_dict"] = (
+                    RawRequestTypedDict(
+                        raw_request_api_base=str(
+                            additional_args.get("api_base") or ""
+                        ),
+                        raw_request_body=self._get_raw_request_body(
+                            additional_args.get("complete_input_dict", {})
+                        ),
+                        raw_request_headers=self._get_masked_headers(
+                            additional_args.get("headers", {}) or {},
+                            ignore_sensitive_headers=True,
+                        ),
+                        error=None,
+                    )
+                )
             except Exception as e:
-                self.model_call_details[
-                    "raw_request_typed_dict"
-                ] = RawRequestTypedDict(
-                    error=str(e),
-                )
-                _metadata[
-                    "raw_request"
-                ] = "Unable to Log \
-raw request: {}".format(
-                    str(e)
-                )
+                self.model_call_details["raw_request_typed_dict"] = (
+                    RawRequestTypedDict(
+                        error=str(e),
+                    )
+                )
+                _metadata["raw_request"] = (
+                    "Unable to Log \
+raw request: {}".format(
+                        str(e)
+                    )
+                )
             if self.logger_fn and callable(self.logger_fn):
                 try:

@@ -957,9 +971,9 @@ class Logging(LiteLLMLoggingBaseClass):
                 verbose_logger.debug(
                     f"response_cost_failure_debug_information: {debug_info}"
                 )
-                self.model_call_details[
-                    "response_cost_failure_debug_information"
-                ] = debug_info
+                self.model_call_details["response_cost_failure_debug_information"] = (
+                    debug_info
+                )
                 return None

         try:

@@ -984,9 +998,9 @@ class Logging(LiteLLMLoggingBaseClass):
                 verbose_logger.debug(
                     f"response_cost_failure_debug_information: {debug_info}"
                 )
-                self.model_call_details[
-                    "response_cost_failure_debug_information"
-                ] = debug_info
+                self.model_call_details["response_cost_failure_debug_information"] = (
+                    debug_info
+                )

                 return None

@@ -1046,9 +1060,9 @@ class Logging(LiteLLMLoggingBaseClass):
             end_time = datetime.datetime.now()
         if self.completion_start_time is None:
             self.completion_start_time = end_time
-            self.model_call_details[
-                "completion_start_time"
-            ] = self.completion_start_time
+            self.model_call_details["completion_start_time"] = (
+                self.completion_start_time
+            )
         self.model_call_details["log_event_type"] = "successful_api_call"
         self.model_call_details["end_time"] = end_time
         self.model_call_details["cache_hit"] = cache_hit

@@ -1127,39 +1141,39 @@ class Logging(LiteLLMLoggingBaseClass):
                         "response_cost"
                     ]
                 else:
-                    self.model_call_details[
-                        "response_cost"
-                    ] = self._response_cost_calculator(result=logging_result)
+                    self.model_call_details["response_cost"] = (
+                        self._response_cost_calculator(result=logging_result)
+                    )
                 ## STANDARDIZED LOGGING PAYLOAD

-                self.model_call_details[
-                    "standard_logging_object"
-                ] = get_standard_logging_object_payload(
-                    kwargs=self.model_call_details,
-                    init_response_obj=logging_result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    logging_obj=self,
-                    status="success",
-                    standard_built_in_tools_params=self.standard_built_in_tools_params,
-                )
+                self.model_call_details["standard_logging_object"] = (
+                    get_standard_logging_object_payload(
+                        kwargs=self.model_call_details,
+                        init_response_obj=logging_result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        logging_obj=self,
+                        status="success",
+                        standard_built_in_tools_params=self.standard_built_in_tools_params,
+                    )
+                )
             elif isinstance(result, dict) or isinstance(result, list):
                 ## STANDARDIZED LOGGING PAYLOAD
-                self.model_call_details[
-                    "standard_logging_object"
-                ] = get_standard_logging_object_payload(
-                    kwargs=self.model_call_details,
-                    init_response_obj=result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    logging_obj=self,
-                    status="success",
-                    standard_built_in_tools_params=self.standard_built_in_tools_params,
-                )
+                self.model_call_details["standard_logging_object"] = (
+                    get_standard_logging_object_payload(
+                        kwargs=self.model_call_details,
+                        init_response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        logging_obj=self,
+                        status="success",
+                        standard_built_in_tools_params=self.standard_built_in_tools_params,
+                    )
+                )
             elif standard_logging_object is not None:
-                self.model_call_details[
-                    "standard_logging_object"
-                ] = standard_logging_object
+                self.model_call_details["standard_logging_object"] = (
+                    standard_logging_object
+                )
             else:  # streaming chunks + image gen.
                 self.model_call_details["response_cost"] = None

@@ -1215,23 +1229,23 @@ class Logging(LiteLLMLoggingBaseClass):
             verbose_logger.debug(
                 "Logging Details LiteLLM-Success Call streaming complete"
             )
-            self.model_call_details[
-                "complete_streaming_response"
-            ] = complete_streaming_response
-            self.model_call_details[
-                "response_cost"
-            ] = self._response_cost_calculator(result=complete_streaming_response)
+            self.model_call_details["complete_streaming_response"] = (
+                complete_streaming_response
+            )
+            self.model_call_details["response_cost"] = (
+                self._response_cost_calculator(result=complete_streaming_response)
+            )
             ## STANDARDIZED LOGGING PAYLOAD
-            self.model_call_details[
-                "standard_logging_object"
-            ] = get_standard_logging_object_payload(
-                kwargs=self.model_call_details,
-                init_response_obj=complete_streaming_response,
-                start_time=start_time,
-                end_time=end_time,
-                logging_obj=self,
-                status="success",
-                standard_built_in_tools_params=self.standard_built_in_tools_params,
-            )
+            self.model_call_details["standard_logging_object"] = (
+                get_standard_logging_object_payload(
+                    kwargs=self.model_call_details,
+                    init_response_obj=complete_streaming_response,
+                    start_time=start_time,
+                    end_time=end_time,
+                    logging_obj=self,
+                    status="success",
+                    standard_built_in_tools_params=self.standard_built_in_tools_params,
+                )
+            )
         callbacks = self.get_combined_callback_list(
             dynamic_success_callbacks=self.dynamic_success_callbacks,

@@ -1580,10 +1594,10 @@ class Logging(LiteLLMLoggingBaseClass):
                     )
                 else:
                     if self.stream and complete_streaming_response:
-                        self.model_call_details[
-                            "complete_response"
-                        ] = self.model_call_details.get(
-                            "complete_streaming_response", {}
-                        )
+                        self.model_call_details["complete_response"] = (
+                            self.model_call_details.get(
+                                "complete_streaming_response", {}
+                            )
+                        )
                         result = self.model_call_details["complete_response"]
                     openMeterLogger.log_success_event(

@@ -1623,10 +1637,10 @@ class Logging(LiteLLMLoggingBaseClass):
                     )
                 else:
                     if self.stream and complete_streaming_response:
-                        self.model_call_details[
-                            "complete_response"
-                        ] = self.model_call_details.get(
-                            "complete_streaming_response", {}
-                        )
+                        self.model_call_details["complete_response"] = (
+                            self.model_call_details.get(
+                                "complete_streaming_response", {}
+                            )
+                        )
                         result = self.model_call_details["complete_response"]

@@ -1733,9 +1747,9 @@ class Logging(LiteLLMLoggingBaseClass):
         if complete_streaming_response is not None:
             print_verbose("Async success callbacks: Got a complete streaming response")

-            self.model_call_details[
-                "async_complete_streaming_response"
-            ] = complete_streaming_response
+            self.model_call_details["async_complete_streaming_response"] = (
+                complete_streaming_response
+            )
             try:
                 if self.model_call_details.get("cache_hit", False) is True:
                     self.model_call_details["response_cost"] = 0.0

@@ -1745,10 +1759,10 @@ class Logging(LiteLLMLoggingBaseClass):
                         model_call_details=self.model_call_details
                     )
                     # base_model defaults to None if not set on model_info
-                    self.model_call_details[
-                        "response_cost"
-                    ] = self._response_cost_calculator(
-                        result=complete_streaming_response
-                    )
+                    self.model_call_details["response_cost"] = (
+                        self._response_cost_calculator(
+                            result=complete_streaming_response
+                        )
+                    )

                     verbose_logger.debug(

@@ -1761,16 +1775,16 @@ class Logging(LiteLLMLoggingBaseClass):
                 self.model_call_details["response_cost"] = None

             ## STANDARDIZED LOGGING PAYLOAD
-            self.model_call_details[
-                "standard_logging_object"
-            ] = get_standard_logging_object_payload(
-                kwargs=self.model_call_details,
-                init_response_obj=complete_streaming_response,
-                start_time=start_time,
-                end_time=end_time,
-                logging_obj=self,
-                status="success",
-                standard_built_in_tools_params=self.standard_built_in_tools_params,
-            )
+            self.model_call_details["standard_logging_object"] = (
+                get_standard_logging_object_payload(
+                    kwargs=self.model_call_details,
+                    init_response_obj=complete_streaming_response,
+                    start_time=start_time,
+                    end_time=end_time,
+                    logging_obj=self,
+                    status="success",
+                    standard_built_in_tools_params=self.standard_built_in_tools_params,
+                )
+            )
         callbacks = self.get_combined_callback_list(
             dynamic_success_callbacks=self.dynamic_async_success_callbacks,

@@ -1976,18 +1990,18 @@ class Logging(LiteLLMLoggingBaseClass):

         ## STANDARDIZED LOGGING PAYLOAD

-        self.model_call_details[
-            "standard_logging_object"
-        ] = get_standard_logging_object_payload(
-            kwargs=self.model_call_details,
-            init_response_obj={},
-            start_time=start_time,
-            end_time=end_time,
-            logging_obj=self,
-            status="failure",
-            error_str=str(exception),
-            original_exception=exception,
-            standard_built_in_tools_params=self.standard_built_in_tools_params,
-        )
+        self.model_call_details["standard_logging_object"] = (
+            get_standard_logging_object_payload(
+                kwargs=self.model_call_details,
+                init_response_obj={},
+                start_time=start_time,
+                end_time=end_time,
+                logging_obj=self,
+                status="failure",
+                error_str=str(exception),
+                original_exception=exception,
+                standard_built_in_tools_params=self.standard_built_in_tools_params,
+            )
+        )
         return start_time, end_time

@@ -2753,9 +2767,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
                 endpoint=arize_config.endpoint,
             )

-            os.environ[
-                "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-            ] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
+            os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+                f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
+            )
             for callback in _in_memory_loggers:
                 if (
                     isinstance(callback, ArizeLogger)

@@ -2779,9 +2793,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915

             # auth can be disabled on local deployments of arize phoenix
             if arize_phoenix_config.otlp_auth_headers is not None:
-                os.environ[
-                    "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-                ] = arize_phoenix_config.otlp_auth_headers
+                os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+                    arize_phoenix_config.otlp_auth_headers
+                )

             for callback in _in_memory_loggers:
                 if (

@@ -2872,9 +2886,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
                 exporter="otlp_http",
                 endpoint="https://langtrace.ai/api/trace",
             )
-            os.environ[
-                "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-            ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
+            os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+                f"api_key={os.getenv('LANGTRACE_API_KEY')}"
+            )
             for callback in _in_memory_loggers:
                 if (
                     isinstance(callback, OpenTelemetry)

@@ -3369,10 +3383,10 @@ class StandardLoggingPayloadSetup:
         for key in StandardLoggingHiddenParams.__annotations__.keys():
             if key in hidden_params:
                 if key == "additional_headers":
-                    clean_hidden_params[
-                        "additional_headers"
-                    ] = StandardLoggingPayloadSetup.get_additional_headers(
-                        hidden_params[key]
-                    )
+                    clean_hidden_params["additional_headers"] = (
+                        StandardLoggingPayloadSetup.get_additional_headers(
+                            hidden_params[key]
+                        )
+                    )
                 else:
                     clean_hidden_params[key] = hidden_params[key]  # type: ignore

@@ -3651,7 +3665,7 @@ def emit_standard_logging_payload(payload: StandardLoggingPayload):


 def get_standard_logging_metadata(
-    metadata: Optional[Dict[str, Any]]
+    metadata: Optional[Dict[str, Any]],
 ) -> StandardLoggingMetadata:
     """
     Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.

@@ -3715,9 +3729,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
         ):
             for k, v in metadata["user_api_key_metadata"].items():
                 if k == "logging":  # prevent logging user logging keys
-                    cleaned_user_api_key_metadata[
-                        k
-                    ] = "scrubbed_by_litellm_for_sensitive_keys"
+                    cleaned_user_api_key_metadata[k] = (
+                        "scrubbed_by_litellm_for_sensitive_keys"
+                    )
                 else:
                     cleaned_user_api_key_metadata[k] = v
@@ -954,7 +954,11 @@ def completion(  # type: ignore # noqa: PLR0915
     non_default_params = get_non_default_completion_params(kwargs=kwargs)
     litellm_params = {}  # used to prevent unbound var errors
     ## PROMPT MANAGEMENT HOOKS ##
-    if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
+    if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and (
+        litellm_logging_obj.should_run_prompt_management_hooks(
+            prompt_id=prompt_id, non_default_params=non_default_params
+        )
+    ):
         (
             model,
             messages,

@@ -2654,9 +2658,9 @@ def completion(  # type: ignore # noqa: PLR0915
                 "aws_region_name" not in optional_params
                 or optional_params["aws_region_name"] is None
             ):
-                optional_params[
-                    "aws_region_name"
-                ] = aws_bedrock_client.meta.region_name
+                optional_params["aws_region_name"] = (
+                    aws_bedrock_client.meta.region_name
+                )

             bedrock_route = BedrockModelInfo.get_bedrock_route(model)
             if bedrock_route == "converse":

@@ -4363,9 +4367,9 @@ def adapter_completion(
     new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

     response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs)  # type: ignore
-    translated_response: Optional[
-        Union[BaseModel, AdapterCompletionStreamWrapper]
-    ] = None
+    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
+        None
+    )
     if isinstance(response, ModelResponse):
         translated_response = translation_obj.translate_completion_output_params(
             response=response

@@ -5785,9 +5789,9 @@ def stream_chunk_builder(  # noqa: PLR0915
     ]

     if len(content_chunks) > 0:
-        response["choices"][0]["message"][
-            "content"
-        ] = processor.get_combined_content(content_chunks)
+        response["choices"][0]["message"]["content"] = (
+            processor.get_combined_content(content_chunks)
+        )

     reasoning_chunks = [
         chunk

@@ -5798,9 +5802,9 @@ def stream_chunk_builder(  # noqa: PLR0915
     ]

     if len(reasoning_chunks) > 0:
-        response["choices"][0]["message"][
-            "reasoning_content"
-        ] = processor.get_combined_reasoning_content(reasoning_chunks)
+        response["choices"][0]["message"]["reasoning_content"] = (
+            processor.get_combined_reasoning_content(reasoning_chunks)
+        )

     audio_chunks = [
         chunk
litellm/types/integrations/anthropic_cache_control_hook.py (new file, 17 lines)

@@ -0,0 +1,17 @@
from typing import Literal, Optional, TypedDict, Union

from litellm.types.llms.openai import ChatCompletionCachedContent


class CacheControlMessageInjectionPoint(TypedDict):
    """Type for message-level injection points."""

    location: Literal["message"]
    role: Optional[
        Literal["user", "system", "assistant"]
    ]  # Optional: target by role (user, system, assistant)
    index: Optional[int]  # Optional: target by specific index
    control: Optional[ChatCompletionCachedContent]


CacheControlInjectionPoint = CacheControlMessageInjectionPoint
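For illustration, one way to construct an injection point against this type (as declared, the TypedDict lists all four keys; the hook reads them with .get(), and the tests below pass only location and role):

    from litellm.types.integrations.anthropic_cache_control_hook import (
        CacheControlMessageInjectionPoint,
    )
    from litellm.types.llms.openai import ChatCompletionCachedContent

    # Target the message at index 2 and request ephemeral caching; the explicit
    # control is optional since the hook defaults to {"type": "ephemeral"}.
    point: CacheControlMessageInjectionPoint = {
        "location": "message",
        "role": None,
        "index": 2,
        "control": ChatCompletionCachedContent(type="ephemeral"),
    }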
tests/litellm/integrations/test_anthropic_cache_control_hook.py (new file, 151 lines)

@@ -0,0 +1,151 @@
import datetime
import json
import os
import sys
import unittest
from typing import List, Optional, Tuple
from unittest.mock import ANY, MagicMock, Mock, patch

import httpx
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system-path
import litellm
from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams


@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_system_message():
    anthropic_cache_control_hook = AnthropicCacheControlHook()
    litellm.callbacks = [anthropic_cache_control_hook]

    # Mock response data
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "output": {
            "message": {
                "role": "assistant",
                "content": "Here is my analysis of the key terms and conditions...",
            }
        },
        "stopReason": "stop_sequence",
        "usage": {
            "inputTokens": 100,
            "outputTokens": 200,
            "totalTokens": 300,
            "cacheReadInputTokens": 100,
            "cacheWriteInputTokens": 200,
        },
    }
    mock_response.status_code = 200

    # Mock AsyncHTTPHandler.post method
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=mock_response) as mock_post:
        response = await litellm.acompletion(
            model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
            messages=[
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": "You are an AI assistant tasked with analyzing legal documents.",
                        },
                        {
                            "type": "text",
                            "text": "Here is the full text of a complex legal agreement",
                        },
                    ],
                },
                {
                    "role": "user",
                    "content": "what are the key terms and conditions in this agreement?",
                },
            ],
            cache_control_injection_points=[
                {
                    "location": "message",
                    "role": "system",
                },
            ],
            client=client,
        )

        mock_post.assert_called_once()
        request_body = json.loads(mock_post.call_args.kwargs["data"])

        print("request_body: ", json.dumps(request_body, indent=4))

        # Verify the request body
        assert request_body["system"][1]["cachePoint"] == {"type": "default"}


@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_user_message():
    anthropic_cache_control_hook = AnthropicCacheControlHook()
    litellm.callbacks = [anthropic_cache_control_hook]

    # Mock response data
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "output": {
            "message": {
                "role": "assistant",
                "content": "Here is my analysis of the key terms and conditions...",
            }
        },
        "stopReason": "stop_sequence",
        "usage": {
            "inputTokens": 100,
            "outputTokens": 200,
            "totalTokens": 300,
            "cacheReadInputTokens": 100,
            "cacheWriteInputTokens": 200,
        },
    }
    mock_response.status_code = 200

    # Mock AsyncHTTPHandler.post method
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=mock_response) as mock_post:
        response = await litellm.acompletion(
            model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
            messages=[
                {
                    "role": "assistant",
                    "content": [
                        {
                            "type": "text",
                            "text": "You are an AI assistant tasked with analyzing legal documents.",
                        },
                    ],
                },
                {
                    "role": "user",
                    "content": "what are the key terms and conditions in this agreement? <very_long_text>",
                },
            ],
            cache_control_injection_points=[
                {
                    "location": "message",
                    "role": "user",
                },
            ],
            client=client,
        )

        mock_post.assert_called_once()
        request_body = json.loads(mock_post.call_args.kwargs["data"])

        print("request_body: ", json.dumps(request_body, indent=4))

        # Verify the request body
        assert request_body["messages"][1]["content"][1]["cachePoint"] == {
            "type": "default"
        }