Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)

Commit a81f7300b9 (parent 94c3de90bb)
fixes for using cache control on ui + backend

4 changed files with 78 additions and 8 deletions

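For orientation, here is a minimal usage sketch (not part of the diff) of what this fix enables: the Admin UI serializes the injection-point index as a string, and the hook changed below now coerces digit strings back to integers. The exact request shape (the "location" key and the plumbing of the parameter through litellm.completion kwargs) is assumed here, not taken from this commit.

    import litellm

    response = litellm.completion(
        model="anthropic/claude-3-5-sonnet-20240620",
        messages=[
            {"role": "system", "content": "Long, reusable system prompt ..."},
            {"role": "user", "content": "Summarize the system prompt."},
        ],
        # UI form fields often arrive as strings, e.g. index="0";
        # the hook now coerces digit strings to ints before indexing messages.
        cache_control_injection_points=[
            {"location": "message", "index": "0", "control": {"type": "ephemeral"}}
        ],
    )
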
@@ -113,6 +113,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
     "pagerduty",
     "humanloop",
     "gcs_pubsub",
+    "anthropic_cache_control_hook",
 ]
 logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
 _known_custom_logger_compatible_callbacks: List = list(

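A hedged aside: since the new literal value also feeds _known_custom_logger_compatible_callbacks, the hook should be addressable by name wherever litellm resolves callback names against that list (assumption; the normal trigger is simply passing cache_control_injection_points).

    import litellm

    # name-based registration, assuming string callbacks are resolved against the literal above
    litellm.callbacks = ["anthropic_cache_control_hook"]
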
@@ -7,8 +7,9 @@ Users can define
 """

 import copy
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Dict, List, Optional, Tuple, Union, cast

+from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.custom_prompt_management import CustomPromptManagement
 from litellm.types.integrations.anthropic_cache_control_hook import (
     CacheControlInjectionPoint,

@@ -64,8 +65,15 @@ class AnthropicCacheControlHook(CustomPromptManagement):
         control: ChatCompletionCachedContent = point.get(
             "control", None
         ) or ChatCompletionCachedContent(type="ephemeral")
-        targetted_index = point.get("index", None)
+
+        _targetted_index: Optional[Union[int, str]] = point.get("index", None)
+        targetted_index: Optional[int] = None
+        if isinstance(_targetted_index, str):
+            if _targetted_index.isdigit():
+                targetted_index = int(_targetted_index)
+        else:
+            targetted_index = _targetted_index

         targetted_role = point.get("role", None)

         # Case 1: Target by specific index

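A standalone illustration (assumption: same semantics as the hunk above) of why the coercion matters: the UI posts indices as strings, SDK callers pass ints, and non-numeric strings simply fall through to None.

    from typing import Optional, Union

    def coerce_index(raw: Optional[Union[int, str]]) -> Optional[int]:
        index: Optional[int] = None
        if isinstance(raw, str):
            if raw.isdigit():
                index = int(raw)  # "2" (UI form value) -> 2
        else:
            index = raw  # already an int, or None, from SDK callers
        return index

    assert coerce_index("2") == 2
    assert coerce_index(2) == 2
    assert coerce_index(None) is None
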
@@ -115,4 +123,28 @@ class AnthropicCacheControlHook(CustomPromptManagement):
     @property
     def integration_name(self) -> str:
         """Return the integration name for this hook."""
-        return "anthropic-cache-control-hook"
+        return "anthropic_cache_control_hook"
+
+    @staticmethod
+    def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
+        if non_default_params.get("cache_control_injection_points", None):
+            return True
+        return False
+
+    @staticmethod
+    def get_custom_logger_for_anthropic_cache_control_hook(
+        non_default_params: Dict,
+    ) -> Optional[CustomLogger]:
+        from litellm.litellm_core_utils.litellm_logging import (
+            _init_custom_logger_compatible_class,
+        )
+
+        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
+            non_default_params
+        ):
+            return _init_custom_logger_compatible_class(
+                logging_integration="anthropic_cache_control_hook",
+                internal_usage_cache=None,
+                llm_router=None,
+            )
+        return None

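How the two new static helpers compose, as a hedged sketch (class and method names come from the diff above; the injection-point shape with a "location" key is an assumption, and the instance reuse comes from the _init_custom_logger_compatible_class hunk further down):

    from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook

    non_default_params = {
        "cache_control_injection_points": [{"location": "message", "role": "system"}]
    }

    if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(non_default_params):
        hook = AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
            non_default_params
        )  # returns an in-memory AnthropicCacheControlHook instance, reused across calls
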
@@ -36,6 +36,7 @@ from litellm.cost_calculator import (
     RealtimeAPITokenUsageProcessor,
     _select_model_name_for_cost_calc,
 )
+from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
 from litellm.integrations.arize.arize import ArizeLogger
 from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger

@@ -465,7 +466,9 @@ class Logging(LiteLLMLoggingBaseClass):
         """
         if prompt_id:
             return True
-        if non_default_params.get("cache_control_injection_points", None):
+        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
+            non_default_params
+        ):
             return True
         return False

@@ -480,8 +483,11 @@ class Logging(LiteLLMLoggingBaseClass):
     ) -> Tuple[str, List[AllMessageValues], dict]:
         custom_logger = (
             prompt_management_logger
-            or self.get_custom_logger_for_prompt_management(model)
+            or self.get_custom_logger_for_prompt_management(
+                model=model, non_default_params=non_default_params
+            )
         )

         if custom_logger:
             (
                 model,

@@ -499,7 +505,7 @@ class Logging(LiteLLMLoggingBaseClass):
         return model, messages, non_default_params

     def get_custom_logger_for_prompt_management(
-        self, model: str
+        self, model: str, non_default_params: Dict
     ) -> Optional[CustomLogger]:
         """
         Get a custom logger for prompt management based on model name or available callbacks.

@@ -534,6 +540,26 @@ class Logging(LiteLLMLoggingBaseClass):
             self.model_call_details["prompt_integration"] = logger.__class__.__name__
             return logger

+        if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
+            non_default_params
+        ):
+            self.model_call_details["prompt_integration"] = (
+                anthropic_cache_control_logger.__class__.__name__
+            )
+            return anthropic_cache_control_logger
+
         return None

+    def get_custom_logger_for_anthropic_cache_control_hook(
+        self, non_default_params: Dict
+    ) -> Optional[CustomLogger]:
+        if non_default_params.get("cache_control_injection_points", None):
+            custom_logger = _init_custom_logger_compatible_class(
+                logging_integration="anthropic_cache_control_hook",
+                internal_usage_cache=None,
+                llm_router=None,
+            )
+            return custom_logger
+        return None
+
     def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict:

@@ -2922,6 +2948,13 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         pagerduty_logger = PagerDutyAlerting(**custom_logger_init_args)
         _in_memory_loggers.append(pagerduty_logger)
         return pagerduty_logger  # type: ignore
+    elif logging_integration == "anthropic_cache_control_hook":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, AnthropicCacheControlHook):
+                return callback
+        anthropic_cache_control_hook = AnthropicCacheControlHook()
+        _in_memory_loggers.append(anthropic_cache_control_hook)
+        return anthropic_cache_control_hook  # type: ignore
     elif logging_integration == "gcs_pubsub":
         for callback in _in_memory_loggers:
             if isinstance(callback, GcsPubSubLogger):

@@ -3060,6 +3093,10 @@ def get_custom_logger_compatible_class(  # noqa: PLR0915
         for callback in _in_memory_loggers:
             if isinstance(callback, PagerDutyAlerting):
                 return callback
+    elif logging_integration == "anthropic_cache_control_hook":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, AnthropicCacheControlHook):
+                return callback
     elif logging_integration == "gcs_pubsub":
         for callback in _in_memory_loggers:
             if isinstance(callback, GcsPubSubLogger):

@@ -10,7 +10,7 @@ class CacheControlMessageInjectionPoint(TypedDict):
     role: Optional[
         Literal["user", "system", "assistant"]
     ]  # Optional: target by role (user, system, assistant)
-    index: Optional[int]  # Optional: target by specific index
+    index: Optional[Union[int, str]]  # Optional: target by specific index
     control: Optional[ChatCompletionCachedContent]

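A last hedged sketch: the widened index type means both of these injection points are now accepted (any other required keys of the TypedDict are omitted here for brevity):

    point_from_sdk = {"role": None, "index": 2, "control": {"type": "ephemeral"}}
    point_from_ui = {"role": None, "index": "2", "control": {"type": "ephemeral"}}  # index arrives as a string from the UI form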