mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
118 lines
4.5 KiB
Python
118 lines
4.5 KiB
Python
"""
|
|
This hook is used to inject cache control directives into the messages of a chat completion.
|
|
|
|
Users can define
|
|
- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points.
|
|
|
|
"""
|
|
|
|
import copy
|
|
from typing import Any, Dict, List, Optional, Tuple, cast
|
|
|
|
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
|
from litellm.types.integrations.anthropic_cache_control_hook import (
|
|
CacheControlInjectionPoint,
|
|
CacheControlMessageInjectionPoint,
|
|
)
|
|
from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
|
|
from litellm.types.utils import StandardCallbackDynamicParams
|
|
|
|
|
|
class AnthropicCacheControlHook(CustomPromptManagement):
    """Inject Anthropic `cache_control` directives into chat-completion messages.

    Users pass `cache_control_injection_points` in the completion params; this
    hook applies the requested cache-control directives to the targeted
    messages before the request is sent to the provider.
    """

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Apply cache control directives based on specified injection points.

        Returns:
        - model: str - the model to use
        - messages: List[AllMessageValues] - messages with applied cache controls
        - non_default_params: dict - params with any global cache controls
        """
        # Extract (and remove) the injection points so they are not forwarded
        # downstream as an unknown completion param.
        injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
            "cache_control_injection_points", []
        )
        if not injection_points:
            return model, messages, non_default_params

        # Deep copy so the caller's original message list is never mutated.
        processed_messages = copy.deepcopy(messages)

        # Apply each message-level injection point in order.
        for point in injection_points:
            if point.get("location") == "message":
                point = cast(CacheControlMessageInjectionPoint, point)
                processed_messages = self._process_message_injection(
                    point=point, messages=processed_messages
                )

        return model, processed_messages, non_default_params

    @staticmethod
    def _process_message_injection(
        point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """Apply one message-level injection point to `messages` and return them.

        The point targets either a single message by index, or every message
        with a matching role. Defaults to an "ephemeral" cache control when
        none is given.
        """
        control: ChatCompletionCachedContent = point.get(
            "control", None
        ) or ChatCompletionCachedContent(type="ephemeral")
        # Fix: the original assigned `index` twice (duplicated statement).
        targeted_index = point.get("index", None)
        targeted_role = point.get("role", None)

        # Case 1: Target by specific index (silently ignored if out of range).
        if targeted_index is not None:
            if 0 <= targeted_index < len(messages):
                messages[targeted_index] = (
                    AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                        messages[targeted_index], control
                    )
                )
        # Case 2: Target by role — every matching message gets the control.
        elif targeted_role is not None:
            for idx, msg in enumerate(messages):
                if msg.get("role") == targeted_role:
                    # Fix: assign back into the list instead of rebinding the
                    # loop variable; the original only worked because the
                    # helper mutates the message dict in place.
                    messages[idx] = (
                        AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                            message=msg, control=control
                        )
                    )
        return messages

    @staticmethod
    def _safe_insert_cache_control_in_message(
        message: AllMessageValues, control: ChatCompletionCachedContent
    ) -> AllMessageValues:
        """
        Safe way to insert cache control in a message.

        OpenAI message content can be either:
        - a string
        - a list of objects

        This method handles inserting cache control in both cases.
        """
        message_content = message.get("content", None)

        # 1. String content: attach the cache control at the message level.
        if isinstance(message_content, str):
            message["cache_control"] = control  # type: ignore
        # 2. List content: attach the cache control to each dict content item.
        elif isinstance(message_content, list):
            for content_item in message_content:
                if isinstance(content_item, dict):
                    content_item["cache_control"] = control  # type: ignore
        return message

    @property
    def integration_name(self) -> str:
        """Return the integration name for this hook."""
        return "anthropic-cache-control-hook"