[Feat] Add support for cache_control_injection_points for Anthropic API, Bedrock API (#9996)

* test_anthropic_cache_control_hook_system_message

* test_anthropic_cache_control_hook.py

* should_run_prompt_management_hooks

* fix should_run_prompt_management_hooks

* test_anthropic_cache_control_hook_specific_index

* fix test

* fix linting errors

* ChatCompletionCachedContent
This commit is contained in:
Ishaan Jaff 2025-04-14 20:50:13 -07:00 committed by GitHub
parent 2ed593e052
commit 6cfa50d278
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 453 additions and 149 deletions

View file

@ -0,0 +1,118 @@
"""
This hook is used to inject cache control directives into the messages of a chat completion.
Users can define
- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points.
"""
import copy
from typing import Any, Dict, List, Optional, Tuple, cast
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.integrations.anthropic_cache_control_hook import (
CacheControlInjectionPoint,
CacheControlMessageInjectionPoint,
)
from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
from litellm.types.utils import StandardCallbackDynamicParams
class AnthropicCacheControlHook(CustomPromptManagement):
def get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:
"""
Apply cache control directives based on specified injection points.
Returns:
- model: str - the model to use
- messages: List[AllMessageValues] - messages with applied cache controls
- non_default_params: dict - params with any global cache controls
"""
# Extract cache control injection points
injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
"cache_control_injection_points", []
)
if not injection_points:
return model, messages, non_default_params
# Create a deep copy of messages to avoid modifying the original list
processed_messages = copy.deepcopy(messages)
# Process message-level cache controls
for point in injection_points:
if point.get("location") == "message":
point = cast(CacheControlMessageInjectionPoint, point)
processed_messages = self._process_message_injection(
point=point, messages=processed_messages
)
return model, processed_messages, non_default_params
@staticmethod
def _process_message_injection(
point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""Process message-level cache control injection."""
control: ChatCompletionCachedContent = point.get(
"control", None
) or ChatCompletionCachedContent(type="ephemeral")
targetted_index = point.get("index", None)
targetted_index = point.get("index", None)
targetted_role = point.get("role", None)
# Case 1: Target by specific index
if targetted_index is not None:
if 0 <= targetted_index < len(messages):
messages[targetted_index] = (
AnthropicCacheControlHook._safe_insert_cache_control_in_message(
messages[targetted_index], control
)
)
# Case 2: Target by role
elif targetted_role is not None:
for msg in messages:
if msg.get("role") == targetted_role:
msg = (
AnthropicCacheControlHook._safe_insert_cache_control_in_message(
message=msg, control=control
)
)
return messages
@staticmethod
def _safe_insert_cache_control_in_message(
message: AllMessageValues, control: ChatCompletionCachedContent
) -> AllMessageValues:
"""
Safe way to insert cache control in a message
OpenAI Message content can be either:
- string
- list of objects
This method handles inserting cache control in both cases.
"""
message_content = message.get("content", None)
# 1. if string, insert cache control in the message
if isinstance(message_content, str):
message["cache_control"] = control # type: ignore
# 2. list of objects
elif isinstance(message_content, list):
for content_item in message_content:
if isinstance(content_item, dict):
content_item["cache_control"] = control # type: ignore
return message
@property
def integration_name(self) -> str:
"""Return the integration name for this hook."""
return "anthropic-cache-control-hook"

View file

@ -249,9 +249,9 @@ class Logging(LiteLLMLoggingBaseClass):
self.litellm_trace_id = litellm_trace_id
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[
Any
] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
[]
) # for generating complete stream response
self.log_raw_request_response = log_raw_request_response
# Initialize dynamic callbacks
@ -455,6 +455,20 @@ class Logging(LiteLLMLoggingBaseClass):
if "custom_llm_provider" in self.model_call_details:
self.custom_llm_provider = self.model_call_details["custom_llm_provider"]
def should_run_prompt_management_hooks(
self,
prompt_id: str,
non_default_params: Dict,
) -> bool:
"""
Return True if prompt management hooks should be run
"""
if prompt_id:
return True
if non_default_params.get("cache_control_injection_points", None):
return True
return False
def get_chat_completion_prompt(
self,
model: str,
@ -557,9 +571,9 @@ class Logging(LiteLLMLoggingBaseClass):
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
self.model_call_details["litellm_params"][
"api_base"
] = self._get_masked_api_base(additional_args.get("api_base", ""))
self.model_call_details["litellm_params"]["api_base"] = (
self._get_masked_api_base(additional_args.get("api_base", ""))
)
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
# Log the exact input to the LLM API
@ -588,10 +602,10 @@ class Logging(LiteLLMLoggingBaseClass):
try:
# [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True:
_metadata[
"raw_request"
] = "redacted by litellm. \
_metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""),
@ -602,32 +616,32 @@ class Logging(LiteLLMLoggingBaseClass):
_metadata["raw_request"] = str(curl_command)
# split up, so it's easier to parse in the UI
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
ignore_sensitive_headers=True,
),
error=None,
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
ignore_sensitive_headers=True,
),
error=None,
)
)
except Exception as e:
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
error=str(e),
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
error=str(e),
)
)
_metadata[
"raw_request"
] = "Unable to Log \
_metadata["raw_request"] = (
"Unable to Log \
raw request: {}".format(
str(e)
str(e)
)
)
if self.logger_fn and callable(self.logger_fn):
try:
@ -957,9 +971,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
return None
try:
@ -984,9 +998,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
return None
@ -1046,9 +1060,9 @@ class Logging(LiteLLMLoggingBaseClass):
end_time = datetime.datetime.now()
if self.completion_start_time is None:
self.completion_start_time = end_time
self.model_call_details[
"completion_start_time"
] = self.completion_start_time
self.model_call_details["completion_start_time"] = (
self.completion_start_time
)
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
self.model_call_details["cache_hit"] = cache_hit
@ -1127,39 +1141,39 @@ class Logging(LiteLLMLoggingBaseClass):
"response_cost"
]
else:
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=logging_result)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=logging_result)
)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=logging_result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=logging_result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
)
elif isinstance(result, dict) or isinstance(result, list):
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
)
elif standard_logging_object is not None:
self.model_call_details[
"standard_logging_object"
] = standard_logging_object
self.model_call_details["standard_logging_object"] = (
standard_logging_object
)
else: # streaming chunks + image gen.
self.model_call_details["response_cost"] = None
@ -1215,23 +1229,23 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=complete_streaming_response)
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=complete_streaming_response)
)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
)
callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_success_callbacks,
@ -1580,10 +1594,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
)
result = self.model_call_details["complete_response"]
openMeterLogger.log_success_event(
@ -1623,10 +1637,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
)
result = self.model_call_details["complete_response"]
@ -1733,9 +1747,9 @@ class Logging(LiteLLMLoggingBaseClass):
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
self.model_call_details[
"async_complete_streaming_response"
] = complete_streaming_response
self.model_call_details["async_complete_streaming_response"] = (
complete_streaming_response
)
try:
if self.model_call_details.get("cache_hit", False) is True:
self.model_call_details["response_cost"] = 0.0
@ -1745,10 +1759,10 @@ class Logging(LiteLLMLoggingBaseClass):
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(
result=complete_streaming_response
self.model_call_details["response_cost"] = (
self._response_cost_calculator(
result=complete_streaming_response
)
)
verbose_logger.debug(
@ -1761,16 +1775,16 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["response_cost"] = None
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
)
callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_async_success_callbacks,
@ -1976,18 +1990,18 @@ class Logging(LiteLLMLoggingBaseClass):
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
)
return start_time, end_time
@ -2753,9 +2767,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
endpoint=arize_config.endpoint,
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
)
for callback in _in_memory_loggers:
if (
isinstance(callback, ArizeLogger)
@ -2779,9 +2793,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
# auth can be disabled on local deployments of arize phoenix
if arize_phoenix_config.otlp_auth_headers is not None:
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = arize_phoenix_config.otlp_auth_headers
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
arize_phoenix_config.otlp_auth_headers
)
for callback in _in_memory_loggers:
if (
@ -2872,9 +2886,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
exporter="otlp_http",
endpoint="https://langtrace.ai/api/trace",
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
)
for callback in _in_memory_loggers:
if (
isinstance(callback, OpenTelemetry)
@ -3369,10 +3383,10 @@ class StandardLoggingPayloadSetup:
for key in StandardLoggingHiddenParams.__annotations__.keys():
if key in hidden_params:
if key == "additional_headers":
clean_hidden_params[
"additional_headers"
] = StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
clean_hidden_params["additional_headers"] = (
StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
)
else:
clean_hidden_params[key] = hidden_params[key] # type: ignore
@ -3651,7 +3665,7 @@ def emit_standard_logging_payload(payload: StandardLoggingPayload):
def get_standard_logging_metadata(
metadata: Optional[Dict[str, Any]]
metadata: Optional[Dict[str, Any]],
) -> StandardLoggingMetadata:
"""
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@ -3715,9 +3729,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
):
for k, v in metadata["user_api_key_metadata"].items():
if k == "logging": # prevent logging user logging keys
cleaned_user_api_key_metadata[
k
] = "scrubbed_by_litellm_for_sensitive_keys"
cleaned_user_api_key_metadata[k] = (
"scrubbed_by_litellm_for_sensitive_keys"
)
else:
cleaned_user_api_key_metadata[k] = v

View file

@ -954,7 +954,11 @@ def completion( # type: ignore # noqa: PLR0915
non_default_params = get_non_default_completion_params(kwargs=kwargs)
litellm_params = {} # used to prevent unbound var errors
## PROMPT MANAGEMENT HOOKS ##
if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and (
litellm_logging_obj.should_run_prompt_management_hooks(
prompt_id=prompt_id, non_default_params=non_default_params
)
):
(
model,
messages,
@ -2654,9 +2658,9 @@ def completion( # type: ignore # noqa: PLR0915
"aws_region_name" not in optional_params
or optional_params["aws_region_name"] is None
):
optional_params[
"aws_region_name"
] = aws_bedrock_client.meta.region_name
optional_params["aws_region_name"] = (
aws_bedrock_client.meta.region_name
)
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
if bedrock_route == "converse":
@ -4363,9 +4367,9 @@ def adapter_completion(
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore
translated_response: Optional[
Union[BaseModel, AdapterCompletionStreamWrapper]
] = None
translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
None
)
if isinstance(response, ModelResponse):
translated_response = translation_obj.translate_completion_output_params(
response=response
@ -5785,9 +5789,9 @@ def stream_chunk_builder( # noqa: PLR0915
]
if len(content_chunks) > 0:
response["choices"][0]["message"][
"content"
] = processor.get_combined_content(content_chunks)
response["choices"][0]["message"]["content"] = (
processor.get_combined_content(content_chunks)
)
reasoning_chunks = [
chunk
@ -5798,9 +5802,9 @@ def stream_chunk_builder( # noqa: PLR0915
]
if len(reasoning_chunks) > 0:
response["choices"][0]["message"][
"reasoning_content"
] = processor.get_combined_reasoning_content(reasoning_chunks)
response["choices"][0]["message"]["reasoning_content"] = (
processor.get_combined_reasoning_content(reasoning_chunks)
)
audio_chunks = [
chunk

View file

@ -0,0 +1,17 @@
from typing import Literal, Optional, TypedDict, Union
from litellm.types.llms.openai import ChatCompletionCachedContent
class CacheControlMessageInjectionPoint(TypedDict):
"""Type for message-level injection points."""
location: Literal["message"]
role: Optional[
Literal["user", "system", "assistant"]
] # Optional: target by role (user, system, assistant)
index: Optional[int] # Optional: target by specific index
control: Optional[ChatCompletionCachedContent]
CacheControlInjectionPoint = CacheControlMessageInjectionPoint

View file

@ -0,0 +1,151 @@
import datetime
import json
import os
import sys
import unittest
from typing import List, Optional, Tuple
from unittest.mock import ANY, MagicMock, Mock, patch
import httpx
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
import litellm
from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_system_message():
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Here is my analysis of the key terms and conditions...",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 100,
"outputTokens": 200,
"totalTokens": 300,
"cacheReadInputTokens": 100,
"cacheWriteInputTokens": 200,
},
}
mock_response.status_code = 200
# Mock AsyncHTTPHandler.post method
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
response = await litellm.acompletion(
model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
{
"type": "text",
"text": "Here is the full text of a complex legal agreement",
},
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
],
cache_control_injection_points=[
{
"location": "message",
"role": "system",
},
],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
print("request_body: ", json.dumps(request_body, indent=4))
# Verify the request body
assert request_body["system"][1]["cachePoint"] == {"type": "default"}
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_user_message():
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Here is my analysis of the key terms and conditions...",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 100,
"outputTokens": 200,
"totalTokens": 300,
"cacheReadInputTokens": 100,
"cacheWriteInputTokens": 200,
},
}
mock_response.status_code = 200
# Mock AsyncHTTPHandler.post method
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
response = await litellm.acompletion(
model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
messages=[
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement? <very_long_text>",
},
],
cache_control_injection_points=[
{
"location": "message",
"role": "user",
},
],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
print("request_body: ", json.dumps(request_body, indent=4))
# Verify the request body
assert request_body["messages"][1]["content"][1]["cachePoint"] == {
"type": "default"
}