mirror of https://github.com/BerriAI/litellm.git

[Feat] Add support for cache_control_injection_points for Anthropic API, Bedrock API (#9996)

* test_anthropic_cache_control_hook_system_message
* test_anthropic_cache_control_hook.py
* should_run_prompt_management_hooks
* fix should_run_prompt_management_hooks
* test_anthropic_cache_control_hook_specific_index
* fix test
* fix linting errors
* ChatCompletionCachedContent

parent 2ed593e052 | commit 6cfa50d278
5 changed files with 453 additions and 149 deletions
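
For context, a minimal usage sketch of the feature this commit adds, assembled from the tests below; it is illustrative only (the model name, messages, and the explicit callback registration mirror the new test file, not a documented public API):

import litellm
from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook

# Register the new hook, then pass cache_control_injection_points in the call.
# The hook injects a cache_control directive into every message matching the
# injection point (here: all system messages) before the provider request is built.
litellm.callbacks = [AnthropicCacheControlHook()]

response = litellm.completion(
    model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
    messages=[
        {
            "role": "system",
            "content": "You are an AI assistant tasked with analyzing legal documents.",
        },
        {"role": "user", "content": "what are the key terms and conditions in this agreement?"},
    ],
    cache_control_injection_points=[
        {"location": "message", "role": "system"},
    ],
)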
litellm/integrations/anthropic_cache_control_hook.py (new file, 118 lines)
@@ -0,0 +1,118 @@
"""
This hook is used to inject cache control directives into the messages of a chat completion.

Users can define
- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points.

"""

import copy
from typing import Any, Dict, List, Optional, Tuple, cast

from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.integrations.anthropic_cache_control_hook import (
    CacheControlInjectionPoint,
    CacheControlMessageInjectionPoint,
)
from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
from litellm.types.utils import StandardCallbackDynamicParams


class AnthropicCacheControlHook(CustomPromptManagement):
    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Apply cache control directives based on specified injection points.

        Returns:
        - model: str - the model to use
        - messages: List[AllMessageValues] - messages with applied cache controls
        - non_default_params: dict - params with any global cache controls
        """
        # Extract cache control injection points
        injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
            "cache_control_injection_points", []
        )
        if not injection_points:
            return model, messages, non_default_params

        # Create a deep copy of messages to avoid modifying the original list
        processed_messages = copy.deepcopy(messages)

        # Process message-level cache controls
        for point in injection_points:
            if point.get("location") == "message":
                point = cast(CacheControlMessageInjectionPoint, point)
                processed_messages = self._process_message_injection(
                    point=point, messages=processed_messages
                )

        return model, processed_messages, non_default_params

    @staticmethod
    def _process_message_injection(
        point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """Process message-level cache control injection."""
        control: ChatCompletionCachedContent = point.get(
            "control", None
        ) or ChatCompletionCachedContent(type="ephemeral")
        targetted_index = point.get("index", None)
        targetted_role = point.get("role", None)

        # Case 1: Target by specific index
        if targetted_index is not None:
            if 0 <= targetted_index < len(messages):
                messages[targetted_index] = (
                    AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                        messages[targetted_index], control
                    )
                )
        # Case 2: Target by role
        elif targetted_role is not None:
            for msg in messages:
                if msg.get("role") == targetted_role:
                    msg = (
                        AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                            message=msg, control=control
                        )
                    )
        return messages

    @staticmethod
    def _safe_insert_cache_control_in_message(
        message: AllMessageValues, control: ChatCompletionCachedContent
    ) -> AllMessageValues:
        """
        Safe way to insert cache control in a message

        OpenAI Message content can be either:
        - string
        - list of objects

        This method handles inserting cache control in both cases.
        """
        message_content = message.get("content", None)

        # 1. if string, insert cache control in the message
        if isinstance(message_content, str):
            message["cache_control"] = control  # type: ignore
        # 2. list of objects
        elif isinstance(message_content, list):
            for content_item in message_content:
                if isinstance(content_item, dict):
                    content_item["cache_control"] = control  # type: ignore
        return message

    @property
    def integration_name(self) -> str:
        """Return the integration name for this hook."""
        return "anthropic-cache-control-hook"
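A small sketch (not part of the commit) of what the message-level injection above produces; the message dicts are illustrative, and the private staticmethod is called directly here only to show the transformation:

from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook

messages = [
    {"role": "system", "content": "You are a legal analyst."},
    {"role": "user", "content": "Summarize the agreement."},
]
point = {"location": "message", "role": "system"}  # shaped like CacheControlMessageInjectionPoint

updated = AnthropicCacheControlHook._process_message_injection(point=point, messages=messages)
# String content: the default ephemeral control is attached to the message itself, e.g.
# {"role": "system", "content": "You are a legal analyst.", "cache_control": {"type": "ephemeral"}}
print(updated[0])
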
@@ -249,9 +249,9 @@ class Logging(LiteLLMLoggingBaseClass):
        self.litellm_trace_id = litellm_trace_id
        self.function_id = function_id
        self.streaming_chunks: List[Any] = []  # for generating complete stream response
-        self.sync_streaming_chunks: List[
-            Any
-        ] = []  # for generating complete stream response
+        self.sync_streaming_chunks: List[Any] = (
+            []
+        )  # for generating complete stream response
        self.log_raw_request_response = log_raw_request_response

        # Initialize dynamic callbacks
@@ -455,6 +455,20 @@ class Logging(LiteLLMLoggingBaseClass):
        if "custom_llm_provider" in self.model_call_details:
            self.custom_llm_provider = self.model_call_details["custom_llm_provider"]

+    def should_run_prompt_management_hooks(
+        self,
+        prompt_id: str,
+        non_default_params: Dict,
+    ) -> bool:
+        """
+        Return True if prompt management hooks should be run
+        """
+        if prompt_id:
+            return True
+        if non_default_params.get("cache_control_injection_points", None):
+            return True
+        return False
+
    def get_chat_completion_prompt(
        self,
        model: str,
@@ -557,9 +571,9 @@ class Logging(LiteLLMLoggingBaseClass):
            model
        ):  # if model name was changes pre-call, overwrite the initial model call name with the new one
            self.model_call_details["model"] = model
-        self.model_call_details["litellm_params"][
-            "api_base"
-        ] = self._get_masked_api_base(additional_args.get("api_base", ""))
+        self.model_call_details["litellm_params"]["api_base"] = (
+            self._get_masked_api_base(additional_args.get("api_base", ""))
+        )

    def pre_call(self, input, api_key, model=None, additional_args={}):  # noqa: PLR0915
        # Log the exact input to the LLM API
@@ -588,10 +602,10 @@ class Logging(LiteLLMLoggingBaseClass):
            try:
                # [Non-blocking Extra Debug Information in metadata]
                if turn_off_message_logging is True:
-                    _metadata[
-                        "raw_request"
-                    ] = "redacted by litellm. \
-                    'litellm.turn_off_message_logging=True'"
+                    _metadata["raw_request"] = (
+                        "redacted by litellm. \
+                    'litellm.turn_off_message_logging=True'"
+                    )
                else:
                    curl_command = self._get_request_curl_command(
                        api_base=additional_args.get("api_base", ""),
@@ -602,32 +616,32 @@ class Logging(LiteLLMLoggingBaseClass):

                    _metadata["raw_request"] = str(curl_command)
                    # split up, so it's easier to parse in the UI
-                    self.model_call_details[
-                        "raw_request_typed_dict"
-                    ] = RawRequestTypedDict(
-                        raw_request_api_base=str(
-                            additional_args.get("api_base") or ""
-                        ),
-                        raw_request_body=self._get_raw_request_body(
-                            additional_args.get("complete_input_dict", {})
-                        ),
-                        raw_request_headers=self._get_masked_headers(
-                            additional_args.get("headers", {}) or {},
-                            ignore_sensitive_headers=True,
-                        ),
-                        error=None,
-                    )
+                    self.model_call_details["raw_request_typed_dict"] = (
+                        RawRequestTypedDict(
+                            raw_request_api_base=str(
+                                additional_args.get("api_base") or ""
+                            ),
+                            raw_request_body=self._get_raw_request_body(
+                                additional_args.get("complete_input_dict", {})
+                            ),
+                            raw_request_headers=self._get_masked_headers(
+                                additional_args.get("headers", {}) or {},
+                                ignore_sensitive_headers=True,
+                            ),
+                            error=None,
+                        )
+                    )
            except Exception as e:
-                self.model_call_details[
-                    "raw_request_typed_dict"
-                ] = RawRequestTypedDict(
-                    error=str(e),
-                )
+                self.model_call_details["raw_request_typed_dict"] = (
+                    RawRequestTypedDict(
+                        error=str(e),
+                    )
+                )
-                _metadata[
-                    "raw_request"
-                ] = "Unable to Log \
-                raw request: {}".format(
-                    str(e)
-                )
+                _metadata["raw_request"] = (
+                    "Unable to Log \
+                raw request: {}".format(
+                        str(e)
+                    )
+                )
                if self.logger_fn and callable(self.logger_fn):
                    try:
@@ -957,9 +971,9 @@ class Logging(LiteLLMLoggingBaseClass):
            verbose_logger.debug(
                f"response_cost_failure_debug_information: {debug_info}"
            )
-            self.model_call_details[
-                "response_cost_failure_debug_information"
-            ] = debug_info
+            self.model_call_details["response_cost_failure_debug_information"] = (
+                debug_info
+            )
            return None

        try:
@@ -984,9 +998,9 @@ class Logging(LiteLLMLoggingBaseClass):
            verbose_logger.debug(
                f"response_cost_failure_debug_information: {debug_info}"
            )
-            self.model_call_details[
-                "response_cost_failure_debug_information"
-            ] = debug_info
+            self.model_call_details["response_cost_failure_debug_information"] = (
+                debug_info
+            )

            return None

@@ -1046,9 +1060,9 @@ class Logging(LiteLLMLoggingBaseClass):
        end_time = datetime.datetime.now()
        if self.completion_start_time is None:
            self.completion_start_time = end_time
-            self.model_call_details[
-                "completion_start_time"
-            ] = self.completion_start_time
+            self.model_call_details["completion_start_time"] = (
+                self.completion_start_time
+            )
        self.model_call_details["log_event_type"] = "successful_api_call"
        self.model_call_details["end_time"] = end_time
        self.model_call_details["cache_hit"] = cache_hit
@@ -1127,39 +1141,39 @@ class Logging(LiteLLMLoggingBaseClass):
                        "response_cost"
                    ]
                else:
-                    self.model_call_details[
-                        "response_cost"
-                    ] = self._response_cost_calculator(result=logging_result)
+                    self.model_call_details["response_cost"] = (
+                        self._response_cost_calculator(result=logging_result)
+                    )
                ## STANDARDIZED LOGGING PAYLOAD

-                self.model_call_details[
-                    "standard_logging_object"
-                ] = get_standard_logging_object_payload(
-                    kwargs=self.model_call_details,
-                    init_response_obj=logging_result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    logging_obj=self,
-                    status="success",
-                    standard_built_in_tools_params=self.standard_built_in_tools_params,
-                )
+                self.model_call_details["standard_logging_object"] = (
+                    get_standard_logging_object_payload(
+                        kwargs=self.model_call_details,
+                        init_response_obj=logging_result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        logging_obj=self,
+                        status="success",
+                        standard_built_in_tools_params=self.standard_built_in_tools_params,
+                    )
+                )
            elif isinstance(result, dict) or isinstance(result, list):
                ## STANDARDIZED LOGGING PAYLOAD
-                self.model_call_details[
-                    "standard_logging_object"
-                ] = get_standard_logging_object_payload(
-                    kwargs=self.model_call_details,
-                    init_response_obj=result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    logging_obj=self,
-                    status="success",
-                    standard_built_in_tools_params=self.standard_built_in_tools_params,
-                )
+                self.model_call_details["standard_logging_object"] = (
+                    get_standard_logging_object_payload(
+                        kwargs=self.model_call_details,
+                        init_response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        logging_obj=self,
+                        status="success",
+                        standard_built_in_tools_params=self.standard_built_in_tools_params,
+                    )
+                )
            elif standard_logging_object is not None:
-                self.model_call_details[
-                    "standard_logging_object"
-                ] = standard_logging_object
+                self.model_call_details["standard_logging_object"] = (
+                    standard_logging_object
+                )
            else:  # streaming chunks + image gen.
                self.model_call_details["response_cost"] = None

@@ -1215,23 +1229,23 @@ class Logging(LiteLLMLoggingBaseClass):
                verbose_logger.debug(
                    "Logging Details LiteLLM-Success Call streaming complete"
                )
-                self.model_call_details[
-                    "complete_streaming_response"
-                ] = complete_streaming_response
-                self.model_call_details[
-                    "response_cost"
-                ] = self._response_cost_calculator(result=complete_streaming_response)
+                self.model_call_details["complete_streaming_response"] = (
+                    complete_streaming_response
+                )
+                self.model_call_details["response_cost"] = (
+                    self._response_cost_calculator(result=complete_streaming_response)
+                )
                ## STANDARDIZED LOGGING PAYLOAD
-                self.model_call_details[
-                    "standard_logging_object"
-                ] = get_standard_logging_object_payload(
-                    kwargs=self.model_call_details,
-                    init_response_obj=complete_streaming_response,
-                    start_time=start_time,
-                    end_time=end_time,
-                    logging_obj=self,
-                    status="success",
-                    standard_built_in_tools_params=self.standard_built_in_tools_params,
-                )
+                self.model_call_details["standard_logging_object"] = (
+                    get_standard_logging_object_payload(
+                        kwargs=self.model_call_details,
+                        init_response_obj=complete_streaming_response,
+                        start_time=start_time,
+                        end_time=end_time,
+                        logging_obj=self,
+                        status="success",
+                        standard_built_in_tools_params=self.standard_built_in_tools_params,
+                    )
+                )
            callbacks = self.get_combined_callback_list(
                dynamic_success_callbacks=self.dynamic_success_callbacks,
@@ -1580,10 +1594,10 @@ class Logging(LiteLLMLoggingBaseClass):
                        )
                    else:
                        if self.stream and complete_streaming_response:
-                            self.model_call_details[
-                                "complete_response"
-                            ] = self.model_call_details.get(
-                                "complete_streaming_response", {}
-                            )
+                            self.model_call_details["complete_response"] = (
+                                self.model_call_details.get(
+                                    "complete_streaming_response", {}
+                                )
+                            )
                            result = self.model_call_details["complete_response"]
                        openMeterLogger.log_success_event(
@@ -1623,10 +1637,10 @@ class Logging(LiteLLMLoggingBaseClass):
                        )
                    else:
                        if self.stream and complete_streaming_response:
-                            self.model_call_details[
-                                "complete_response"
-                            ] = self.model_call_details.get(
-                                "complete_streaming_response", {}
-                            )
+                            self.model_call_details["complete_response"] = (
+                                self.model_call_details.get(
+                                    "complete_streaming_response", {}
+                                )
+                            )
                            result = self.model_call_details["complete_response"]

@@ -1733,9 +1747,9 @@ class Logging(LiteLLMLoggingBaseClass):
        if complete_streaming_response is not None:
            print_verbose("Async success callbacks: Got a complete streaming response")

-            self.model_call_details[
-                "async_complete_streaming_response"
-            ] = complete_streaming_response
+            self.model_call_details["async_complete_streaming_response"] = (
+                complete_streaming_response
+            )
            try:
                if self.model_call_details.get("cache_hit", False) is True:
                    self.model_call_details["response_cost"] = 0.0
@@ -1745,10 +1759,10 @@ class Logging(LiteLLMLoggingBaseClass):
                        model_call_details=self.model_call_details
                    )
                    # base_model defaults to None if not set on model_info
-                    self.model_call_details[
-                        "response_cost"
-                    ] = self._response_cost_calculator(
-                        result=complete_streaming_response
-                    )
+                    self.model_call_details["response_cost"] = (
+                        self._response_cost_calculator(
+                            result=complete_streaming_response
+                        )
+                    )

                    verbose_logger.debug(
@@ -1761,16 +1775,16 @@ class Logging(LiteLLMLoggingBaseClass):
                    self.model_call_details["response_cost"] = None

            ## STANDARDIZED LOGGING PAYLOAD
-            self.model_call_details[
-                "standard_logging_object"
-            ] = get_standard_logging_object_payload(
-                kwargs=self.model_call_details,
-                init_response_obj=complete_streaming_response,
-                start_time=start_time,
-                end_time=end_time,
-                logging_obj=self,
-                status="success",
-                standard_built_in_tools_params=self.standard_built_in_tools_params,
-            )
+            self.model_call_details["standard_logging_object"] = (
+                get_standard_logging_object_payload(
+                    kwargs=self.model_call_details,
+                    init_response_obj=complete_streaming_response,
+                    start_time=start_time,
+                    end_time=end_time,
+                    logging_obj=self,
+                    status="success",
+                    standard_built_in_tools_params=self.standard_built_in_tools_params,
+                )
+            )
        callbacks = self.get_combined_callback_list(
            dynamic_success_callbacks=self.dynamic_async_success_callbacks,
@@ -1976,18 +1990,18 @@ class Logging(LiteLLMLoggingBaseClass):

        ## STANDARDIZED LOGGING PAYLOAD

-        self.model_call_details[
-            "standard_logging_object"
-        ] = get_standard_logging_object_payload(
-            kwargs=self.model_call_details,
-            init_response_obj={},
-            start_time=start_time,
-            end_time=end_time,
-            logging_obj=self,
-            status="failure",
-            error_str=str(exception),
-            original_exception=exception,
-            standard_built_in_tools_params=self.standard_built_in_tools_params,
-        )
+        self.model_call_details["standard_logging_object"] = (
+            get_standard_logging_object_payload(
+                kwargs=self.model_call_details,
+                init_response_obj={},
+                start_time=start_time,
+                end_time=end_time,
+                logging_obj=self,
+                status="failure",
+                error_str=str(exception),
+                original_exception=exception,
+                standard_built_in_tools_params=self.standard_built_in_tools_params,
+            )
+        )
        return start_time, end_time

@@ -2753,9 +2767,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
                endpoint=arize_config.endpoint,
            )

-            os.environ[
-                "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-            ] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
+            os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+                f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
+            )
            for callback in _in_memory_loggers:
                if (
                    isinstance(callback, ArizeLogger)
@@ -2779,9 +2793,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915

            # auth can be disabled on local deployments of arize phoenix
            if arize_phoenix_config.otlp_auth_headers is not None:
-                os.environ[
-                    "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-                ] = arize_phoenix_config.otlp_auth_headers
+                os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+                    arize_phoenix_config.otlp_auth_headers
+                )

            for callback in _in_memory_loggers:
                if (
@@ -2872,9 +2886,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
                exporter="otlp_http",
                endpoint="https://langtrace.ai/api/trace",
            )
-            os.environ[
-                "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
-            ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
+            os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
+                f"api_key={os.getenv('LANGTRACE_API_KEY')}"
+            )
            for callback in _in_memory_loggers:
                if (
                    isinstance(callback, OpenTelemetry)
@@ -3369,10 +3383,10 @@ class StandardLoggingPayloadSetup:
        for key in StandardLoggingHiddenParams.__annotations__.keys():
            if key in hidden_params:
                if key == "additional_headers":
-                    clean_hidden_params[
-                        "additional_headers"
-                    ] = StandardLoggingPayloadSetup.get_additional_headers(
-                        hidden_params[key]
-                    )
+                    clean_hidden_params["additional_headers"] = (
+                        StandardLoggingPayloadSetup.get_additional_headers(
+                            hidden_params[key]
+                        )
+                    )
                else:
                    clean_hidden_params[key] = hidden_params[key]  # type: ignore
@@ -3651,7 +3665,7 @@ def emit_standard_logging_payload(payload: StandardLoggingPayload):


def get_standard_logging_metadata(
-    metadata: Optional[Dict[str, Any]]
+    metadata: Optional[Dict[str, Any]],
) -> StandardLoggingMetadata:
    """
    Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -3715,9 +3729,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
    ):
        for k, v in metadata["user_api_key_metadata"].items():
            if k == "logging":  # prevent logging user logging keys
-                cleaned_user_api_key_metadata[
-                    k
-                ] = "scrubbed_by_litellm_for_sensitive_keys"
+                cleaned_user_api_key_metadata[k] = (
+                    "scrubbed_by_litellm_for_sensitive_keys"
+                )
            else:
                cleaned_user_api_key_metadata[k] = v

@@ -954,7 +954,11 @@ def completion(  # type: ignore # noqa: PLR0915
    non_default_params = get_non_default_completion_params(kwargs=kwargs)
    litellm_params = {}  # used to prevent unbound var errors
    ## PROMPT MANAGEMENT HOOKS ##
-    if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
+    if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and (
+        litellm_logging_obj.should_run_prompt_management_hooks(
+            prompt_id=prompt_id, non_default_params=non_default_params
+        )
+    ):
        (
            model,
            messages,
@@ -2654,9 +2658,9 @@ def completion(  # type: ignore # noqa: PLR0915
                "aws_region_name" not in optional_params
                or optional_params["aws_region_name"] is None
            ):
-                optional_params[
-                    "aws_region_name"
-                ] = aws_bedrock_client.meta.region_name
+                optional_params["aws_region_name"] = (
+                    aws_bedrock_client.meta.region_name
+                )

            bedrock_route = BedrockModelInfo.get_bedrock_route(model)
            if bedrock_route == "converse":
@@ -4363,9 +4367,9 @@ def adapter_completion(
    new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

    response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs)  # type: ignore
-    translated_response: Optional[
-        Union[BaseModel, AdapterCompletionStreamWrapper]
-    ] = None
+    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
+        None
+    )
    if isinstance(response, ModelResponse):
        translated_response = translation_obj.translate_completion_output_params(
            response=response
@@ -5785,9 +5789,9 @@ def stream_chunk_builder(  # noqa: PLR0915
        ]

        if len(content_chunks) > 0:
-            response["choices"][0]["message"][
-                "content"
-            ] = processor.get_combined_content(content_chunks)
+            response["choices"][0]["message"]["content"] = (
+                processor.get_combined_content(content_chunks)
+            )

        reasoning_chunks = [
            chunk
@@ -5798,9 +5802,9 @@ def stream_chunk_builder(  # noqa: PLR0915
        ]

        if len(reasoning_chunks) > 0:
-            response["choices"][0]["message"][
-                "reasoning_content"
-            ] = processor.get_combined_reasoning_content(reasoning_chunks)
+            response["choices"][0]["message"]["reasoning_content"] = (
+                processor.get_combined_reasoning_content(reasoning_chunks)
+            )

        audio_chunks = [
            chunk
litellm/types/integrations/anthropic_cache_control_hook.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from typing import Literal, Optional, TypedDict, Union

from litellm.types.llms.openai import ChatCompletionCachedContent


class CacheControlMessageInjectionPoint(TypedDict):
    """Type for message-level injection points."""

    location: Literal["message"]
    role: Optional[
        Literal["user", "system", "assistant"]
    ]  # Optional: target by role (user, system, assistant)
    index: Optional[int]  # Optional: target by specific index
    control: Optional[ChatCompletionCachedContent]


CacheControlInjectionPoint = CacheControlMessageInjectionPoint
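
For illustration (not part of the diff), an index-targeted injection point built against the TypedDict above; the values are hypothetical:

from litellm.types.integrations.anthropic_cache_control_hook import (
    CacheControlMessageInjectionPoint,
)
from litellm.types.llms.openai import ChatCompletionCachedContent

# Target the first message in the list instead of matching by role.
point: CacheControlMessageInjectionPoint = {
    "location": "message",
    "role": None,
    "index": 0,
    "control": ChatCompletionCachedContent(type="ephemeral"),
}
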
tests/litellm/integrations/test_anthropic_cache_control_hook.py (new file, 151 lines)
@@ -0,0 +1,151 @@
import datetime
import json
import os
import sys
import unittest
from typing import List, Optional, Tuple
from unittest.mock import ANY, MagicMock, Mock, patch

import httpx
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system-path
import litellm
from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams


@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_system_message():
    anthropic_cache_control_hook = AnthropicCacheControlHook()
    litellm.callbacks = [anthropic_cache_control_hook]

    # Mock response data
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "output": {
            "message": {
                "role": "assistant",
                "content": "Here is my analysis of the key terms and conditions...",
            }
        },
        "stopReason": "stop_sequence",
        "usage": {
            "inputTokens": 100,
            "outputTokens": 200,
            "totalTokens": 300,
            "cacheReadInputTokens": 100,
            "cacheWriteInputTokens": 200,
        },
    }
    mock_response.status_code = 200

    # Mock AsyncHTTPHandler.post method
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=mock_response) as mock_post:
        response = await litellm.acompletion(
            model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
            messages=[
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": "You are an AI assistant tasked with analyzing legal documents.",
                        },
                        {
                            "type": "text",
                            "text": "Here is the full text of a complex legal agreement",
                        },
                    ],
                },
                {
                    "role": "user",
                    "content": "what are the key terms and conditions in this agreement?",
                },
            ],
            cache_control_injection_points=[
                {
                    "location": "message",
                    "role": "system",
                },
            ],
            client=client,
        )

        mock_post.assert_called_once()
        request_body = json.loads(mock_post.call_args.kwargs["data"])

        print("request_body: ", json.dumps(request_body, indent=4))

        # Verify the request body
        assert request_body["system"][1]["cachePoint"] == {"type": "default"}


@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_user_message():
    anthropic_cache_control_hook = AnthropicCacheControlHook()
    litellm.callbacks = [anthropic_cache_control_hook]

    # Mock response data
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "output": {
            "message": {
                "role": "assistant",
                "content": "Here is my analysis of the key terms and conditions...",
            }
        },
        "stopReason": "stop_sequence",
        "usage": {
            "inputTokens": 100,
            "outputTokens": 200,
            "totalTokens": 300,
            "cacheReadInputTokens": 100,
            "cacheWriteInputTokens": 200,
        },
    }
    mock_response.status_code = 200

    # Mock AsyncHTTPHandler.post method
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=mock_response) as mock_post:
        response = await litellm.acompletion(
            model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
            messages=[
                {
                    "role": "assistant",
                    "content": [
                        {
                            "type": "text",
                            "text": "You are an AI assistant tasked with analyzing legal documents.",
                        },
                    ],
                },
                {
                    "role": "user",
                    "content": "what are the key terms and conditions in this agreement? <very_long_text>",
                },
            ],
            cache_control_injection_points=[
                {
                    "location": "message",
                    "role": "user",
                },
            ],
            client=client,
        )

        mock_post.assert_called_once()
        request_body = json.loads(mock_post.call_args.kwargs["data"])

        print("request_body: ", json.dumps(request_body, indent=4))

        # Verify the request body
        assert request_body["messages"][1]["content"][1]["cachePoint"] == {
            "type": "default"
        }