diff --git a/litellm/integrations/anthropic_cache_control_hook.py b/litellm/integrations/anthropic_cache_control_hook.py
new file mode 100644
index 0000000000..ba6b9ae2f6
--- /dev/null
+++ b/litellm/integrations/anthropic_cache_control_hook.py
@@ -0,0 +1,117 @@
+"""
+This hook is used to inject cache control directives into the messages of a chat completion.
+
+Users can define
+- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points.
+
+"""
+
+import copy
+from typing import List, Optional, Tuple, cast
+
+from litellm.integrations.custom_prompt_management import CustomPromptManagement
+from litellm.types.integrations.anthropic_cache_control_hook import (
+    CacheControlInjectionPoint,
+    CacheControlMessageInjectionPoint,
+)
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
+from litellm.types.utils import StandardCallbackDynamicParams
+
+
+class AnthropicCacheControlHook(CustomPromptManagement):
+    def get_chat_completion_prompt(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        non_default_params: dict,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Tuple[str, List[AllMessageValues], dict]:
+        """
+        Apply cache control directives based on specified injection points.
+
+        Returns:
+        - model: str - the model to use
+        - messages: List[AllMessageValues] - messages with applied cache controls
+        - non_default_params: dict - params with any global cache controls
+        """
+        # Extract cache control injection points
+        injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
+            "cache_control_injection_points", []
+        )
+        if not injection_points:
+            return model, messages, non_default_params
+
+        # Create a deep copy of messages to avoid modifying the original list
+        processed_messages = copy.deepcopy(messages)
+
+        # Process message-level cache controls
+        for point in injection_points:
+            if point.get("location") == "message":
+                point = cast(CacheControlMessageInjectionPoint, point)
+                processed_messages = self._process_message_injection(
+                    point=point, messages=processed_messages
+                )
+
+        return model, processed_messages, non_default_params
+
+    @staticmethod
+    def _process_message_injection(
+        point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        """Process message-level cache control injection."""
+        control = point.get("control", {}) or ChatCompletionCachedContent(
+            type="ephemeral"
+        )
+        targetted_index = point.get("index", None)
+        targetted_role = point.get("role", None)
+
+        # Case 1: Target by specific index
+        if targetted_index is not None:
+            if 0 <= targetted_index < len(messages):
+                messages[targetted_index] = (
+                    AnthropicCacheControlHook._safe_insert_cache_control_in_message(
+                        messages[targetted_index], control
+                    )
+                )
+        # Case 2: Target by role
+        elif targetted_role is not None:
+            for msg in messages:
+                if msg.get("role") == targetted_role:
+                    msg = (
+                        AnthropicCacheControlHook._safe_insert_cache_control_in_message(
+                            msg, control
+                        )
+                    )
+        return messages
+
+    @staticmethod
+    def _safe_insert_cache_control_in_message(
+        message: AllMessageValues, control: ChatCompletionCachedContent
+    ) -> AllMessageValues:
+        """
+        Safe way to insert cache control in a message
+
+        OpenAI Message content can be either:
+        - string
+        - list of objects
+
+        This method handles inserting cache control in both cases.
+        """
+        message_content = message.get("content", None)
+
+        # 1. if string, insert cache control in the message
+        if isinstance(message_content, str):
+            message["cache_control"] = control  # type: ignore
+        # 2. list of objects
+        elif isinstance(message_content, list):
+            for content_item in message_content:
+                if isinstance(content_item, dict):
+                    content_item["cache_control"] = control  # type: ignore
+        return message
+
+    @property
+    def integration_name(self) -> str:
+        """Return the integration name for this hook."""
+        return "anthropic-cache-control-hook"
diff --git a/litellm/types/integrations/anthropic_cache_control_hook.py b/litellm/types/integrations/anthropic_cache_control_hook.py
new file mode 100644
index 0000000000..edbd84a485
--- /dev/null
+++ b/litellm/types/integrations/anthropic_cache_control_hook.py
@@ -0,0 +1,17 @@
+from typing import Literal, Optional, TypedDict, Union
+
+from litellm.types.llms.openai import ChatCompletionCachedContent
+
+
+class CacheControlMessageInjectionPoint(TypedDict):
+    """Type for message-level injection points."""
+
+    location: Literal["message"]
+    role: Optional[
+        Literal["user", "system", "assistant"]
+    ]  # Optional: target by role (user, system, assistant)
+    index: Optional[int]  # Optional: target by specific index
+    control: Optional[ChatCompletionCachedContent]
+
+
+CacheControlInjectionPoint = CacheControlMessageInjectionPoint