Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 02:34:29 +00:00
Commit d986b5d6b1 (parent 53ce0d434e): test_anthropic_cache_control_hook.py
2 changed files with 134 additions and 0 deletions
litellm/integrations/anthropic_cache_control_hook.py (new file, 117 lines)
@@ -0,0 +1,117 @@
"""
|
||||
This hook is used to inject cache control directives into the messages of a chat completion.
|
||||
|
||||
Users can define
|
||||
- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points.
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import List, Optional, Tuple, cast
|
||||
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
from litellm.types.integrations.anthropic_cache_control_hook import (
|
||||
CacheControlInjectionPoint,
|
||||
CacheControlMessageInjectionPoint,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
|
||||
class AnthropicCacheControlHook(CustomPromptManagement):
    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Apply cache control directives based on specified injection points.

        Returns:
        - model: str - the model to use
        - messages: List[AllMessageValues] - messages with applied cache controls
        - non_default_params: dict - params with any global cache controls
        """
        # Extract cache control injection points
        injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
            "cache_control_injection_points", []
        )
        if not injection_points:
            return model, messages, non_default_params

        # Create a deep copy of messages to avoid modifying the original list
        processed_messages = copy.deepcopy(messages)

        # Process message-level cache controls
        for point in injection_points:
            if point.get("location") == "message":
                point = cast(CacheControlMessageInjectionPoint, point)
                processed_messages = self._process_message_injection(
                    point=point, messages=processed_messages
                )

        return model, processed_messages, non_default_params

    @staticmethod
    def _process_message_injection(
        point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """Process message-level cache control injection."""
        control = point.get("control", {}) or ChatCompletionCachedContent(
            type="ephemeral"
        )
        targeted_index = point.get("index", None)
        targeted_role = point.get("role", None)

        # Case 1: Target by specific index
        if targeted_index is not None:
            if 0 <= targeted_index < len(messages):
                messages[targeted_index] = (
                    AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                        messages[targeted_index], control
                    )
                )
        # Case 2: Target by role
        elif targeted_role is not None:
            for msg in messages:
                if msg.get("role") == targeted_role:
                    # The helper mutates the message dict in place, so the
                    # rebinding of `msg` here does not drop the update.
                    msg = (
                        AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                            msg, control
                        )
                    )
        return messages

    @staticmethod
    def _safe_insert_cache_control_in_message(
        message: AllMessageValues, control: ChatCompletionCachedContent
    ) -> AllMessageValues:
        """
        Safe way to insert cache control in a message.

        OpenAI message content can be either:
        - a string
        - a list of objects

        This method handles inserting cache control in both cases.
        """
        message_content = message.get("content", None)

        # 1. If the content is a string, set cache control on the message itself
        if isinstance(message_content, str):
            message["cache_control"] = control  # type: ignore
        # 2. If the content is a list of objects, set it on each dict item
        elif isinstance(message_content, list):
            for content_item in message_content:
                if isinstance(content_item, dict):
                    content_item["cache_control"] = control  # type: ignore
        return message

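    # Editorial illustration (not part of the original commit) of the two
    # content shapes handled above:
    #   {"role": "user", "content": "hi"}
    #       -> the message itself gains {"cache_control": {...}}
    #   {"role": "user", "content": [{"type": "text", "text": "hi"}]}
    #       -> each dict item in the content list gains {"cache_control": {...}}
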
    @property
    def integration_name(self) -> str:
        """Return the integration name for this hook."""
        return "anthropic-cache-control-hook"
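Below is a minimal sketch (not part of the commit) of what the hook does to a message list. It calls the static helper directly for illustration; in normal use this runs inside get_chat_completion_prompt. The message contents are made up.

from litellm.integrations.anthropic_cache_control_hook import (
    AnthropicCacheControlHook,
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the attached contract."},
]

# Mark every system message as cacheable with an ephemeral cache control.
point = {
    "location": "message",
    "role": "system",
    "index": None,
    "control": {"type": "ephemeral"},
}

updated = AnthropicCacheControlHook._process_message_injection(
    point=point, messages=messages
)
# updated[0] is now:
# {"role": "system", "content": "You are a helpful assistant.",
#  "cache_control": {"type": "ephemeral"}}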
litellm/types/integrations/anthropic_cache_control_hook.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from typing import Literal, Optional, TypedDict, Union

from litellm.types.llms.openai import ChatCompletionCachedContent


class CacheControlMessageInjectionPoint(TypedDict):
    """Type for message-level injection points."""

    location: Literal["message"]
    role: Optional[
        Literal["user", "system", "assistant"]
    ]  # Optional: target by role (user, system, assistant)
    index: Optional[int]  # Optional: target by specific index
    control: Optional[ChatCompletionCachedContent]


CacheControlInjectionPoint = CacheControlMessageInjectionPoint
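A hedged sketch of the request-level usage implied by the module docstring: `cache_control_injection_points` is passed as a completion param and popped by the hook. The model name and message contents here are assumptions, and how the hook gets registered with the request pipeline is outside this commit.

import litellm

response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # assumed model name
    messages=[
        {"role": "system", "content": "Reusable long system prompt..."},
        {"role": "user", "content": "First question."},
    ],
    # Read by AnthropicCacheControlHook.get_chat_completion_prompt; keys the
    # hook looks up with .get() defaults ("index", "control") may be omitted.
    cache_control_injection_points=[
        {"location": "message", "role": "system"},
    ],
)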