Fixes for using cache control on the UI + backend

This commit is contained in:
Ishaan Jaff 2025-04-14 19:21:29 -07:00
parent 94c3de90bb
commit a81f7300b9
4 changed files with 78 additions and 8 deletions

View file

@ -113,6 +113,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
"pagerduty",
"humanloop",
"gcs_pubsub",
"anthropic_cache_control_hook",
]
logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
_known_custom_logger_compatible_callbacks: List = list(

View file

@ -7,8 +7,9 @@ Users can define
"""
import copy
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Dict, List, Optional, Tuple, Union, cast
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.integrations.anthropic_cache_control_hook import (
CacheControlInjectionPoint,
@ -64,8 +65,15 @@ class AnthropicCacheControlHook(CustomPromptManagement):
control: ChatCompletionCachedContent = point.get(
"control", None
) or ChatCompletionCachedContent(type="ephemeral")
targetted_index = point.get("index", None)
targetted_index = point.get("index", None)
_targetted_index: Optional[Union[int, str]] = point.get("index", None)
targetted_index: Optional[int] = None
if isinstance(_targetted_index, str):
if _targetted_index.isdigit():
targetted_index = int(_targetted_index)
else:
targetted_index = _targetted_index
targetted_role = point.get("role", None)
# Case 1: Target by specific index
@ -115,4 +123,28 @@ class AnthropicCacheControlHook(CustomPromptManagement):
@property
def integration_name(self) -> str:
    """Return the integration name for this hook.

    Must match the literal registered in
    ``_custom_logger_compatible_callbacks_literal``
    ("anthropic_cache_control_hook"), otherwise callback lookup by
    integration name fails.
    """
    # The stale hyphenated value ("anthropic-cache-control-hook") was left
    # behind as an unreachable duplicate return; only the underscored form
    # is registered as a compatible callback literal.
    return "anthropic_cache_control_hook"
@staticmethod
def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
    """Return True when the caller supplied cache-control injection points.

    The hook is only needed when ``cache_control_injection_points`` is
    present and truthy in the request's non-default parameters.
    """
    injection_points = non_default_params.get("cache_control_injection_points")
    return bool(injection_points)
@staticmethod
def get_custom_logger_for_anthropic_cache_control_hook(
    non_default_params: Dict,
) -> Optional[CustomLogger]:
    """Return the shared AnthropicCacheControlHook logger when needed.

    Returns None when ``cache_control_injection_points`` is absent from
    *non_default_params*; otherwise initializes (or reuses) the
    "anthropic_cache_control_hook" compatible logger.
    """
    # Imported lazily to avoid a circular import with litellm_logging.
    from litellm.litellm_core_utils.litellm_logging import (
        _init_custom_logger_compatible_class,
    )

    # Guard clause: no injection points means no hook is required.
    if not AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
        non_default_params
    ):
        return None

    return _init_custom_logger_compatible_class(
        logging_integration="anthropic_cache_control_hook",
        internal_usage_cache=None,
        llm_router=None,
    )

View file

@ -36,6 +36,7 @@ from litellm.cost_calculator import (
RealtimeAPITokenUsageProcessor,
_select_model_name_for_cost_calc,
)
from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
@ -465,7 +466,9 @@ class Logging(LiteLLMLoggingBaseClass):
"""
if prompt_id:
return True
if non_default_params.get("cache_control_injection_points", None):
if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
non_default_params
):
return True
return False
@ -480,8 +483,11 @@ class Logging(LiteLLMLoggingBaseClass):
) -> Tuple[str, List[AllMessageValues], dict]:
custom_logger = (
prompt_management_logger
or self.get_custom_logger_for_prompt_management(model)
or self.get_custom_logger_for_prompt_management(
model=model, non_default_params=non_default_params
)
)
if custom_logger:
(
model,
@ -499,7 +505,7 @@ class Logging(LiteLLMLoggingBaseClass):
return model, messages, non_default_params
def get_custom_logger_for_prompt_management(
self, model: str
self, model: str, non_default_params: Dict
) -> Optional[CustomLogger]:
"""
Get a custom logger for prompt management based on model name or available callbacks.
@ -534,6 +540,26 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["prompt_integration"] = logger.__class__.__name__
return logger
if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
non_default_params
):
self.model_call_details["prompt_integration"] = (
anthropic_cache_control_logger.__class__.__name__
)
return anthropic_cache_control_logger
return None
def get_custom_logger_for_anthropic_cache_control_hook(
    self, non_default_params: Dict
) -> Optional[CustomLogger]:
    """Return the cache-control hook logger if the request needs it.

    Delegates to ``AnthropicCacheControlHook``'s own factory so the
    "should use" check and logger construction live in one place instead
    of re-implementing the ``cache_control_injection_points`` check and
    ``_init_custom_logger_compatible_class`` call here.

    Returns None when ``cache_control_injection_points`` is not present
    in *non_default_params*.
    """
    return AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
        non_default_params
    )
def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict:
@ -2922,6 +2948,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
pagerduty_logger = PagerDutyAlerting(**custom_logger_init_args)
_in_memory_loggers.append(pagerduty_logger)
return pagerduty_logger # type: ignore
elif logging_integration == "anthropic_cache_control_hook":
for callback in _in_memory_loggers:
if isinstance(callback, AnthropicCacheControlHook):
return callback
anthropic_cache_control_hook = AnthropicCacheControlHook()
_in_memory_loggers.append(anthropic_cache_control_hook)
return anthropic_cache_control_hook # type: ignore
elif logging_integration == "gcs_pubsub":
for callback in _in_memory_loggers:
if isinstance(callback, GcsPubSubLogger):
@ -3060,6 +3093,10 @@ def get_custom_logger_compatible_class( # noqa: PLR0915
for callback in _in_memory_loggers:
if isinstance(callback, PagerDutyAlerting):
return callback
elif logging_integration == "anthropic_cache_control_hook":
for callback in _in_memory_loggers:
if isinstance(callback, AnthropicCacheControlHook):
return callback
elif logging_integration == "gcs_pubsub":
for callback in _in_memory_loggers:
if isinstance(callback, GcsPubSubLogger):

View file

@ -10,7 +10,7 @@ class CacheControlMessageInjectionPoint(TypedDict):
role: Optional[
Literal["user", "system", "assistant"]
] # Optional: target by role (user, system, assistant)
index: Optional[int] # Optional: target by specific index
index: Optional[Union[int, str]] # Optional: target by specific index
control: Optional[ChatCompletionCachedContent]