diff --git a/litellm/integrations/anthropic_cache_control_hook.py b/litellm/integrations/anthropic_cache_control_hook.py
new file mode 100644
index 0000000000..f41d579cdf
--- /dev/null
+++ b/litellm/integrations/anthropic_cache_control_hook.py
@@ -0,0 +1,117 @@
+"""
+This hook is used to inject cache control directives into the messages of a chat completion.
+
+Users can define
+- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points.
+
+"""
+
+import copy
+from typing import Any, Dict, List, Optional, Tuple, cast
+
+from litellm.integrations.custom_prompt_management import CustomPromptManagement
+from litellm.types.integrations.anthropic_cache_control_hook import (
+    CacheControlInjectionPoint,
+    CacheControlMessageInjectionPoint,
+)
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
+from litellm.types.utils import StandardCallbackDynamicParams
+
+
+class AnthropicCacheControlHook(CustomPromptManagement):
+    def get_chat_completion_prompt(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        non_default_params: dict,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Tuple[str, List[AllMessageValues], dict]:
+        """
+        Apply cache control directives based on specified injection points.
+
+        Returns:
+        - model: str - the model to use
+        - messages: List[AllMessageValues] - messages with applied cache controls
+        - non_default_params: dict - params with any global cache controls
+        """
+        # Extract cache control injection points
+        injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
+            "cache_control_injection_points", []
+        )
+        if not injection_points:
+            return model, messages, non_default_params
+
+        # Create a deep copy of messages to avoid modifying the original list
+        processed_messages = copy.deepcopy(messages)
+
+        # Process message-level cache controls
+        for point in injection_points:
+            if point.get("location") == "message":
+                point = cast(CacheControlMessageInjectionPoint, point)
+                processed_messages = self._process_message_injection(
+                    point=point, messages=processed_messages
+                )
+
+        return model, processed_messages, non_default_params
+
+    @staticmethod
+    def _process_message_injection(
+        point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        """Process message-level cache control injection."""
+        control: ChatCompletionCachedContent = point.get(
+            "control", None
+        ) or ChatCompletionCachedContent(type="ephemeral")
+        targetted_index = point.get("index", None)
+        targetted_role = point.get("role", None)
+
+        # Case 1: Target by specific index
+        if targetted_index is not None:
+            if 0 <= targetted_index < len(messages):
+                messages[targetted_index] = (
+                    AnthropicCacheControlHook._safe_insert_cache_control_in_message(
+                        messages[targetted_index], control
+                    )
+                )
+        # Case 2: Target by role
+        elif targetted_role is not None:
+            for msg in messages:
+                if msg.get("role") == targetted_role:
+                    msg = (
+                        AnthropicCacheControlHook._safe_insert_cache_control_in_message(
+                            message=msg, control=control
+                        )
+                    )
+        return messages
+
+    @staticmethod
+    def _safe_insert_cache_control_in_message(
+        message: AllMessageValues, control: ChatCompletionCachedContent
+    ) -> AllMessageValues:
+        """
+        Safe way to insert cache control in a message
+
+        OpenAI Message content can be either:
+        - 
string + - list of objects + + This method handles inserting cache control in both cases. + """ + message_content = message.get("content", None) + + # 1. if string, insert cache control in the message + if isinstance(message_content, str): + message["cache_control"] = control # type: ignore + # 2. list of objects + elif isinstance(message_content, list): + for content_item in message_content: + if isinstance(content_item, dict): + content_item["cache_control"] = control # type: ignore + return message + + @property + def integration_name(self) -> str: + """Return the integration name for this hook.""" + return "anthropic-cache-control-hook" diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index af59901b33..c5d59adca5 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -249,9 +249,9 @@ class Logging(LiteLLMLoggingBaseClass): self.litellm_trace_id = litellm_trace_id self.function_id = function_id self.streaming_chunks: List[Any] = [] # for generating complete stream response - self.sync_streaming_chunks: List[ - Any - ] = [] # for generating complete stream response + self.sync_streaming_chunks: List[Any] = ( + [] + ) # for generating complete stream response self.log_raw_request_response = log_raw_request_response # Initialize dynamic callbacks @@ -455,6 +455,20 @@ class Logging(LiteLLMLoggingBaseClass): if "custom_llm_provider" in self.model_call_details: self.custom_llm_provider = self.model_call_details["custom_llm_provider"] + def should_run_prompt_management_hooks( + self, + prompt_id: str, + non_default_params: Dict, + ) -> bool: + """ + Return True if prompt management hooks should be run + """ + if prompt_id: + return True + if non_default_params.get("cache_control_injection_points", None): + return True + return False + def get_chat_completion_prompt( self, model: str, @@ -557,9 +571,9 @@ class Logging(LiteLLMLoggingBaseClass): model ): # if model name was changes pre-call, overwrite the initial model call name with the new one self.model_call_details["model"] = model - self.model_call_details["litellm_params"][ - "api_base" - ] = self._get_masked_api_base(additional_args.get("api_base", "")) + self.model_call_details["litellm_params"]["api_base"] = ( + self._get_masked_api_base(additional_args.get("api_base", "")) + ) def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 # Log the exact input to the LLM API @@ -588,10 +602,10 @@ class Logging(LiteLLMLoggingBaseClass): try: # [Non-blocking Extra Debug Information in metadata] if turn_off_message_logging is True: - _metadata[ - "raw_request" - ] = "redacted by litellm. \ + _metadata["raw_request"] = ( + "redacted by litellm. 
\ 'litellm.turn_off_message_logging=True'" + ) else: curl_command = self._get_request_curl_command( api_base=additional_args.get("api_base", ""), @@ -602,32 +616,32 @@ class Logging(LiteLLMLoggingBaseClass): _metadata["raw_request"] = str(curl_command) # split up, so it's easier to parse in the UI - self.model_call_details[ - "raw_request_typed_dict" - ] = RawRequestTypedDict( - raw_request_api_base=str( - additional_args.get("api_base") or "" - ), - raw_request_body=self._get_raw_request_body( - additional_args.get("complete_input_dict", {}) - ), - raw_request_headers=self._get_masked_headers( - additional_args.get("headers", {}) or {}, - ignore_sensitive_headers=True, - ), - error=None, + self.model_call_details["raw_request_typed_dict"] = ( + RawRequestTypedDict( + raw_request_api_base=str( + additional_args.get("api_base") or "" + ), + raw_request_body=self._get_raw_request_body( + additional_args.get("complete_input_dict", {}) + ), + raw_request_headers=self._get_masked_headers( + additional_args.get("headers", {}) or {}, + ignore_sensitive_headers=True, + ), + error=None, + ) ) except Exception as e: - self.model_call_details[ - "raw_request_typed_dict" - ] = RawRequestTypedDict( - error=str(e), + self.model_call_details["raw_request_typed_dict"] = ( + RawRequestTypedDict( + error=str(e), + ) ) - _metadata[ - "raw_request" - ] = "Unable to Log \ + _metadata["raw_request"] = ( + "Unable to Log \ raw request: {}".format( - str(e) + str(e) + ) ) if self.logger_fn and callable(self.logger_fn): try: @@ -957,9 +971,9 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details[ - "response_cost_failure_debug_information" - ] = debug_info + self.model_call_details["response_cost_failure_debug_information"] = ( + debug_info + ) return None try: @@ -984,9 +998,9 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details[ - "response_cost_failure_debug_information" - ] = debug_info + self.model_call_details["response_cost_failure_debug_information"] = ( + debug_info + ) return None @@ -1046,9 +1060,9 @@ class Logging(LiteLLMLoggingBaseClass): end_time = datetime.datetime.now() if self.completion_start_time is None: self.completion_start_time = end_time - self.model_call_details[ - "completion_start_time" - ] = self.completion_start_time + self.model_call_details["completion_start_time"] = ( + self.completion_start_time + ) self.model_call_details["log_event_type"] = "successful_api_call" self.model_call_details["end_time"] = end_time self.model_call_details["cache_hit"] = cache_hit @@ -1127,39 +1141,39 @@ class Logging(LiteLLMLoggingBaseClass): "response_cost" ] else: - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator(result=logging_result) + self.model_call_details["response_cost"] = ( + self._response_cost_calculator(result=logging_result) + ) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=logging_result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, + self.model_call_details["standard_logging_object"] = ( + get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=logging_result, + start_time=start_time, + 
end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) ) elif isinstance(result, dict) or isinstance(result, list): ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, + self.model_call_details["standard_logging_object"] = ( + get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) ) elif standard_logging_object is not None: - self.model_call_details[ - "standard_logging_object" - ] = standard_logging_object + self.model_call_details["standard_logging_object"] = ( + standard_logging_object + ) else: # streaming chunks + image gen. self.model_call_details["response_cost"] = None @@ -1215,23 +1229,23 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( "Logging Details LiteLLM-Success Call streaming complete" ) - self.model_call_details[ - "complete_streaming_response" - ] = complete_streaming_response - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator(result=complete_streaming_response) + self.model_call_details["complete_streaming_response"] = ( + complete_streaming_response + ) + self.model_call_details["response_cost"] = ( + self._response_cost_calculator(result=complete_streaming_response) + ) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, + self.model_call_details["standard_logging_object"] = ( + get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) ) callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_success_callbacks, @@ -1580,10 +1594,10 @@ class Logging(LiteLLMLoggingBaseClass): ) else: if self.stream and complete_streaming_response: - self.model_call_details[ - "complete_response" - ] = self.model_call_details.get( - "complete_streaming_response", {} + self.model_call_details["complete_response"] = ( + self.model_call_details.get( + "complete_streaming_response", {} + ) ) result = self.model_call_details["complete_response"] openMeterLogger.log_success_event( @@ -1623,10 +1637,10 @@ class Logging(LiteLLMLoggingBaseClass): ) else: if self.stream and complete_streaming_response: - self.model_call_details[ - "complete_response" - ] = self.model_call_details.get( - "complete_streaming_response", {} + self.model_call_details["complete_response"] = ( + self.model_call_details.get( + "complete_streaming_response", {} + ) ) result = self.model_call_details["complete_response"] @@ -1733,9 +1747,9 @@ class Logging(LiteLLMLoggingBaseClass): if complete_streaming_response is not None: print_verbose("Async success 
callbacks: Got a complete streaming response") - self.model_call_details[ - "async_complete_streaming_response" - ] = complete_streaming_response + self.model_call_details["async_complete_streaming_response"] = ( + complete_streaming_response + ) try: if self.model_call_details.get("cache_hit", False) is True: self.model_call_details["response_cost"] = 0.0 @@ -1745,10 +1759,10 @@ class Logging(LiteLLMLoggingBaseClass): model_call_details=self.model_call_details ) # base_model defaults to None if not set on model_info - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator( - result=complete_streaming_response + self.model_call_details["response_cost"] = ( + self._response_cost_calculator( + result=complete_streaming_response + ) ) verbose_logger.debug( @@ -1761,16 +1775,16 @@ class Logging(LiteLLMLoggingBaseClass): self.model_call_details["response_cost"] = None ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, + self.model_call_details["standard_logging_object"] = ( + get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) ) callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_async_success_callbacks, @@ -1976,18 +1990,18 @@ class Logging(LiteLLMLoggingBaseClass): ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj={}, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="failure", - error_str=str(exception), - original_exception=exception, - standard_built_in_tools_params=self.standard_built_in_tools_params, + self.model_call_details["standard_logging_object"] = ( + get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj={}, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="failure", + error_str=str(exception), + original_exception=exception, + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) ) return start_time, end_time @@ -2753,9 +2767,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 endpoint=arize_config.endpoint, ) - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}" + os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( + f"space_key={arize_config.space_key},api_key={arize_config.api_key}" + ) for callback in _in_memory_loggers: if ( isinstance(callback, ArizeLogger) @@ -2779,9 +2793,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 # auth can be disabled on local deployments of arize phoenix if arize_phoenix_config.otlp_auth_headers is not None: - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = arize_phoenix_config.otlp_auth_headers + os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( + arize_phoenix_config.otlp_auth_headers + ) for callback in _in_memory_loggers: if ( @@ -2872,9 +2886,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 
exporter="otlp_http", endpoint="https://langtrace.ai/api/trace", ) - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}" + os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( + f"api_key={os.getenv('LANGTRACE_API_KEY')}" + ) for callback in _in_memory_loggers: if ( isinstance(callback, OpenTelemetry) @@ -3369,10 +3383,10 @@ class StandardLoggingPayloadSetup: for key in StandardLoggingHiddenParams.__annotations__.keys(): if key in hidden_params: if key == "additional_headers": - clean_hidden_params[ - "additional_headers" - ] = StandardLoggingPayloadSetup.get_additional_headers( - hidden_params[key] + clean_hidden_params["additional_headers"] = ( + StandardLoggingPayloadSetup.get_additional_headers( + hidden_params[key] + ) ) else: clean_hidden_params[key] = hidden_params[key] # type: ignore @@ -3651,7 +3665,7 @@ def emit_standard_logging_payload(payload: StandardLoggingPayload): def get_standard_logging_metadata( - metadata: Optional[Dict[str, Any]] + metadata: Optional[Dict[str, Any]], ) -> StandardLoggingMetadata: """ Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata. @@ -3715,9 +3729,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): ): for k, v in metadata["user_api_key_metadata"].items(): if k == "logging": # prevent logging user logging keys - cleaned_user_api_key_metadata[ - k - ] = "scrubbed_by_litellm_for_sensitive_keys" + cleaned_user_api_key_metadata[k] = ( + "scrubbed_by_litellm_for_sensitive_keys" + ) else: cleaned_user_api_key_metadata[k] = v diff --git a/litellm/main.py b/litellm/main.py index 3f1d9a1e76..fd17a283a0 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -954,7 +954,11 @@ def completion( # type: ignore # noqa: PLR0915 non_default_params = get_non_default_completion_params(kwargs=kwargs) litellm_params = {} # used to prevent unbound var errors ## PROMPT MANAGEMENT HOOKS ## - if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None: + if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and ( + litellm_logging_obj.should_run_prompt_management_hooks( + prompt_id=prompt_id, non_default_params=non_default_params + ) + ): ( model, messages, @@ -2654,9 +2658,9 @@ def completion( # type: ignore # noqa: PLR0915 "aws_region_name" not in optional_params or optional_params["aws_region_name"] is None ): - optional_params[ - "aws_region_name" - ] = aws_bedrock_client.meta.region_name + optional_params["aws_region_name"] = ( + aws_bedrock_client.meta.region_name + ) bedrock_route = BedrockModelInfo.get_bedrock_route(model) if bedrock_route == "converse": @@ -4363,9 +4367,9 @@ def adapter_completion( new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs) response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore - translated_response: Optional[ - Union[BaseModel, AdapterCompletionStreamWrapper] - ] = None + translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = ( + None + ) if isinstance(response, ModelResponse): translated_response = translation_obj.translate_completion_output_params( response=response @@ -5785,9 +5789,9 @@ def stream_chunk_builder( # noqa: PLR0915 ] if len(content_chunks) > 0: - response["choices"][0]["message"][ - "content" - ] = processor.get_combined_content(content_chunks) + response["choices"][0]["message"]["content"] = ( + processor.get_combined_content(content_chunks) + ) reasoning_chunks = [ chunk @@ -5798,9 +5802,9 @@ 
def stream_chunk_builder( # noqa: PLR0915 ] if len(reasoning_chunks) > 0: - response["choices"][0]["message"][ - "reasoning_content" - ] = processor.get_combined_reasoning_content(reasoning_chunks) + response["choices"][0]["message"]["reasoning_content"] = ( + processor.get_combined_reasoning_content(reasoning_chunks) + ) audio_chunks = [ chunk diff --git a/litellm/types/integrations/anthropic_cache_control_hook.py b/litellm/types/integrations/anthropic_cache_control_hook.py new file mode 100644 index 0000000000..edbd84a485 --- /dev/null +++ b/litellm/types/integrations/anthropic_cache_control_hook.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional, TypedDict, Union + +from litellm.types.llms.openai import ChatCompletionCachedContent + + +class CacheControlMessageInjectionPoint(TypedDict): + """Type for message-level injection points.""" + + location: Literal["message"] + role: Optional[ + Literal["user", "system", "assistant"] + ] # Optional: target by role (user, system, assistant) + index: Optional[int] # Optional: target by specific index + control: Optional[ChatCompletionCachedContent] + + +CacheControlInjectionPoint = CacheControlMessageInjectionPoint diff --git a/tests/litellm/integrations/test_anthropic_cache_control_hook.py b/tests/litellm/integrations/test_anthropic_cache_control_hook.py new file mode 100644 index 0000000000..fd5f3698ac --- /dev/null +++ b/tests/litellm/integrations/test_anthropic_cache_control_hook.py @@ -0,0 +1,151 @@ +import datetime +import json +import os +import sys +import unittest +from typing import List, Optional, Tuple +from unittest.mock import ANY, MagicMock, Mock, patch + +import httpx +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system-path +import litellm +from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import StandardCallbackDynamicParams + + +@pytest.mark.asyncio +async def test_anthropic_cache_control_hook_system_message(): + anthropic_cache_control_hook = AnthropicCacheControlHook() + litellm.callbacks = [anthropic_cache_control_hook] + + # Mock response data + mock_response = MagicMock() + mock_response.json.return_value = { + "output": { + "message": { + "role": "assistant", + "content": "Here is my analysis of the key terms and conditions...", + } + }, + "stopReason": "stop_sequence", + "usage": { + "inputTokens": 100, + "outputTokens": 200, + "totalTokens": 300, + "cacheReadInputTokens": 100, + "cacheWriteInputTokens": 200, + }, + } + mock_response.status_code = 200 + + # Mock AsyncHTTPHandler.post method + client = AsyncHTTPHandler() + with patch.object(client, "post", return_value=mock_response) as mock_post: + response = await litellm.acompletion( + model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0", + messages=[ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are an AI assistant tasked with analyzing legal documents.", + }, + { + "type": "text", + "text": "Here is the full text of a complex legal agreement", + }, + ], + }, + { + "role": "user", + "content": "what are the key terms and conditions in this agreement?", + }, + ], + cache_control_injection_points=[ + { + "location": "message", + "role": "system", + }, + ], + client=client, + ) + + mock_post.assert_called_once() + request_body = json.loads(mock_post.call_args.kwargs["data"]) + + 
print("request_body: ", json.dumps(request_body, indent=4)) + + # Verify the request body + assert request_body["system"][1]["cachePoint"] == {"type": "default"} + + +@pytest.mark.asyncio +async def test_anthropic_cache_control_hook_user_message(): + anthropic_cache_control_hook = AnthropicCacheControlHook() + litellm.callbacks = [anthropic_cache_control_hook] + + # Mock response data + mock_response = MagicMock() + mock_response.json.return_value = { + "output": { + "message": { + "role": "assistant", + "content": "Here is my analysis of the key terms and conditions...", + } + }, + "stopReason": "stop_sequence", + "usage": { + "inputTokens": 100, + "outputTokens": 200, + "totalTokens": 300, + "cacheReadInputTokens": 100, + "cacheWriteInputTokens": 200, + }, + } + mock_response.status_code = 200 + + # Mock AsyncHTTPHandler.post method + client = AsyncHTTPHandler() + with patch.object(client, "post", return_value=mock_response) as mock_post: + response = await litellm.acompletion( + model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0", + messages=[ + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "You are an AI assistant tasked with analyzing legal documents.", + }, + ], + }, + { + "role": "user", + "content": "what are the key terms and conditions in this agreement? ", + }, + ], + cache_control_injection_points=[ + { + "location": "message", + "role": "user", + }, + ], + client=client, + ) + + mock_post.assert_called_once() + request_body = json.loads(mock_post.call_args.kwargs["data"]) + + print("request_body: ", json.dumps(request_body, indent=4)) + + # Verify the request body + assert request_body["messages"][1]["content"][1]["cachePoint"] == { + "type": "default" + }