diff --git a/litellm/__init__.py b/litellm/__init__.py
index f27fa98029..2075f78820 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -113,6 +113,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
     "pagerduty",
     "humanloop",
     "gcs_pubsub",
+    "anthropic_cache_control_hook",
 ]
 logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
 _known_custom_logger_compatible_callbacks: List = list(
diff --git a/litellm/integrations/anthropic_cache_control_hook.py b/litellm/integrations/anthropic_cache_control_hook.py
index f41d579cdf..c138b3cc25 100644
--- a/litellm/integrations/anthropic_cache_control_hook.py
+++ b/litellm/integrations/anthropic_cache_control_hook.py
@@ -7,8 +7,9 @@ Users can define
 """
 
 import copy
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Dict, List, Optional, Tuple, Union, cast
 
+from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.custom_prompt_management import CustomPromptManagement
 from litellm.types.integrations.anthropic_cache_control_hook import (
     CacheControlInjectionPoint,
@@ -24,7 +25,7 @@ class AnthropicCacheControlHook(CustomPromptManagement):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -64,8 +65,15 @@ class AnthropicCacheControlHook(CustomPromptManagement):
             control: ChatCompletionCachedContent = point.get(
                 "control", None
             ) or ChatCompletionCachedContent(type="ephemeral")
-            targetted_index = point.get("index", None)
-            targetted_index = point.get("index", None)
+
+            _targetted_index: Optional[Union[int, str]] = point.get("index", None)
+            targetted_index: Optional[int] = None
+            if isinstance(_targetted_index, str):
+                if _targetted_index.isdigit():
+                    targetted_index = int(_targetted_index)
+            else:
+                targetted_index = _targetted_index
+
             targetted_role = point.get("role", None)
 
             # Case 1: Target by specific index
@@ -115,4 +123,28 @@ class AnthropicCacheControlHook(CustomPromptManagement):
     @property
     def integration_name(self) -> str:
         """Return the integration name for this hook."""
-        return "anthropic-cache-control-hook"
+        return "anthropic_cache_control_hook"
+
+    @staticmethod
+    def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
+        if non_default_params.get("cache_control_injection_points", None):
+            return True
+        return False
+
+    @staticmethod
+    def get_custom_logger_for_anthropic_cache_control_hook(
+        non_default_params: Dict,
+    ) -> Optional[CustomLogger]:
+        from litellm.litellm_core_utils.litellm_logging import (
+            _init_custom_logger_compatible_class,
+        )
+
+        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
+            non_default_params
+        ):
+            return _init_custom_logger_compatible_class(
+                logging_integration="anthropic_cache_control_hook",
+                internal_usage_cache=None,
+                llm_router=None,
+            )
+        return None
diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index a7471f32f4..18cb8e8d7f 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -94,7 +94,7 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
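For orientation, here is a minimal, hedged sketch of the caller-facing shape this hook reacts to. Only the `cache_control_injection_points` keyword and its role/index/control fields come from this diff; the model name, message contents, and the `location: "message"` literal are illustrative assumptions.

```python
import litellm

# Sketch: the extra kwarg below lands in `non_default_params`, where
# should_use_anthropic_cache_control_hook() detects it and the hook then
# injects {"type": "ephemeral"} cache_control markers into the targeted
# messages before the request is sent to Anthropic.
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # illustrative model name
    messages=[
        {"role": "system", "content": "Long, reusable system prompt ..."},
        {"role": "user", "content": "Answer using the policy above."},
    ],
    cache_control_injection_points=[
        {"location": "message", "role": "system"},  # cache every system message
    ],
)
```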
diff --git a/litellm/integrations/custom_prompt_management.py b/litellm/integrations/custom_prompt_management.py
index 5b34ef0c34..9d05e7b242 100644
--- a/litellm/integrations/custom_prompt_management.py
+++ b/litellm/integrations/custom_prompt_management.py
@@ -15,7 +15,7 @@ class CustomPromptManagement(CustomLogger, PromptManagementBase):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
diff --git a/litellm/integrations/humanloop.py b/litellm/integrations/humanloop.py
index 4651238af4..853fbe148c 100644
--- a/litellm/integrations/humanloop.py
+++ b/litellm/integrations/humanloop.py
@@ -152,14 +152,21 @@ class HumanloopLogger(CustomLogger):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[str, List[AllMessageValues], dict,]:
+    ) -> Tuple[
+        str,
+        List[AllMessageValues],
+        dict,
+    ]:
         humanloop_api_key = dynamic_callback_params.get(
             "humanloop_api_key"
         ) or get_secret_str("HUMANLOOP_API_KEY")
 
+        if prompt_id is None:
+            raise ValueError("prompt_id is required for Humanloop integration")
+
         if humanloop_api_key is None:
             return super().get_chat_completion_prompt(
                 model=model,
diff --git a/litellm/integrations/langfuse/langfuse_prompt_management.py b/litellm/integrations/langfuse/langfuse_prompt_management.py
index 30f991ebd6..dcd3d9933a 100644
--- a/litellm/integrations/langfuse/langfuse_prompt_management.py
+++ b/litellm/integrations/langfuse/langfuse_prompt_management.py
@@ -169,10 +169,14 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[str, List[AllMessageValues], dict,]:
+    ) -> Tuple[
+        str,
+        List[AllMessageValues],
+        dict,
+    ]:
         return self.get_chat_completion_prompt(
             model,
             messages,
diff --git a/litellm/integrations/prompt_management_base.py b/litellm/integrations/prompt_management_base.py
index 07b6720ffe..270c34be8a 100644
--- a/litellm/integrations/prompt_management_base.py
+++ b/litellm/integrations/prompt_management_base.py
@@ -79,10 +79,12 @@ class PromptManagementBase(ABC):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[str, List[AllMessageValues], dict,]:
+    ) -> Tuple[str, List[AllMessageValues], dict]:
+        if prompt_id is None:
+            raise ValueError("prompt_id is required for Prompt Management Base class")
         if not self.should_run_prompt_management(
             prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
         ):
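Because `prompt_id` is now `Optional[str]` on every `get_chat_completion_prompt` override, a custom hook has to tolerate `None` (the hook path can now be triggered by `cache_control_injection_points` alone, with no prompt id). A minimal sketch of a tolerant subclass; the class name and the pass-through behaviour are illustrative, not part of this PR:

```python
from typing import Optional

from litellm.integrations.custom_prompt_management import CustomPromptManagement


class PassthroughPromptManagement(CustomPromptManagement):
    """Illustrative hook that only does work when a prompt_id was supplied."""

    def get_chat_completion_prompt(
        self,
        model,
        messages,
        non_default_params,
        prompt_id: Optional[str],
        prompt_variables,
        dynamic_callback_params,
    ):
        if prompt_id is None:
            # Triggered without a prompt_id (e.g. the cache-control-only path):
            # nothing to resolve, hand everything back unchanged.
            return model, messages, non_default_params
        # ... resolve prompt_id against your own prompt store here ...
        return model, messages, non_default_params
```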
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index c5d59adca5..4d0b93f390 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -36,6 +36,7 @@ from litellm.cost_calculator import (
     RealtimeAPITokenUsageProcessor,
     _select_model_name_for_cost_calc,
 )
+from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
 from litellm.integrations.arize.arize import ArizeLogger
 from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
@@ -457,15 +458,17 @@ class Logging(LiteLLMLoggingBaseClass):
 
     def should_run_prompt_management_hooks(
         self,
-        prompt_id: str,
         non_default_params: Dict,
+        prompt_id: Optional[str] = None,
     ) -> bool:
         """
         Return True if prompt management hooks should be run
         """
         if prompt_id:
             return True
-        if non_default_params.get("cache_control_injection_points", None):
+        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
+            non_default_params
+        ):
             return True
         return False
 
@@ -473,15 +476,18 @@
         self,
         model: str,
         messages: List[AllMessageValues],
-        non_default_params: dict,
-        prompt_id: str,
+        non_default_params: Dict,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         prompt_management_logger: Optional[CustomLogger] = None,
     ) -> Tuple[str, List[AllMessageValues], dict]:
         custom_logger = (
             prompt_management_logger
-            or self.get_custom_logger_for_prompt_management(model)
+            or self.get_custom_logger_for_prompt_management(
+                model=model, non_default_params=non_default_params
+            )
         )
+
         if custom_logger:
             (
                 model,
@@ -490,7 +496,7 @@
             ) = custom_logger.get_chat_completion_prompt(
                 model=model,
                 messages=messages,
-                non_default_params=non_default_params,
+                non_default_params=non_default_params or {},
                 prompt_id=prompt_id,
                 prompt_variables=prompt_variables,
                 dynamic_callback_params=self.standard_callback_dynamic_params,
@@ -499,7 +505,7 @@
         return model, messages, non_default_params
 
     def get_custom_logger_for_prompt_management(
-        self, model: str
+        self, model: str, non_default_params: Dict
     ) -> Optional[CustomLogger]:
         """
         Get a custom logger for prompt management based on model name or available callbacks.
@@ -534,6 +540,26 @@ class Logging(LiteLLMLoggingBaseClass):
                 self.model_call_details["prompt_integration"] = logger.__class__.__name__
                 return logger
 
+        if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
+            non_default_params
+        ):
+            self.model_call_details["prompt_integration"] = (
+                anthropic_cache_control_logger.__class__.__name__
+            )
+            return anthropic_cache_control_logger
+
+        return None
+
+    def get_custom_logger_for_anthropic_cache_control_hook(
+        self, non_default_params: Dict
+    ) -> Optional[CustomLogger]:
+        if non_default_params.get("cache_control_injection_points", None):
+            custom_logger = _init_custom_logger_compatible_class(
+                logging_integration="anthropic_cache_control_hook",
+                internal_usage_cache=None,
+                llm_router=None,
+            )
+            return custom_logger
         return None
 
     def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict:
@@ -2922,6 +2948,13 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         pagerduty_logger = PagerDutyAlerting(**custom_logger_init_args)
         _in_memory_loggers.append(pagerduty_logger)
         return pagerduty_logger  # type: ignore
+    elif logging_integration == "anthropic_cache_control_hook":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, AnthropicCacheControlHook):
+                return callback
+        anthropic_cache_control_hook = AnthropicCacheControlHook()
+        _in_memory_loggers.append(anthropic_cache_control_hook)
+        return anthropic_cache_control_hook  # type: ignore
     elif logging_integration == "gcs_pubsub":
         for callback in _in_memory_loggers:
             if isinstance(callback, GcsPubSubLogger):
@@ -3060,6 +3093,10 @@ def get_custom_logger_compatible_class(  # noqa: PLR0915
         for callback in _in_memory_loggers:
             if isinstance(callback, PagerDutyAlerting):
                 return callback
+    elif logging_integration == "anthropic_cache_control_hook":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, AnthropicCacheControlHook):
+                return callback
     elif logging_integration == "gcs_pubsub":
         for callback in _in_memory_loggers:
             if isinstance(callback, GcsPubSubLogger):
diff --git a/litellm/proxy/custom_prompt_management.py b/litellm/proxy/custom_prompt_management.py
index 7f320ac00a..fc16f4a490 100644
--- a/litellm/proxy/custom_prompt_management.py
+++ b/litellm/proxy/custom_prompt_management.py
@@ -12,7 +12,7 @@ class X42PromptManagement(CustomPromptManagement):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
diff --git a/litellm/types/integrations/anthropic_cache_control_hook.py b/litellm/types/integrations/anthropic_cache_control_hook.py
index edbd84a485..88f22ac1b7 100644
--- a/litellm/types/integrations/anthropic_cache_control_hook.py
+++ b/litellm/types/integrations/anthropic_cache_control_hook.py
@@ -10,7 +10,7 @@ class CacheControlMessageInjectionPoint(TypedDict):
     role: Optional[
         Literal["user", "system", "assistant"]
     ]  # Optional: target by role (user, system, assistant)
-    index: Optional[int]  # Optional: target by specific index
+    index: Optional[Union[int, str]]  # Optional: target by specific index
     control: Optional[ChatCompletionCachedContent]
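Because the dashboard form can submit the index as a string, `index` is widened to `Union[int, str]` above and the hook normalises it before use. A standalone paraphrase of that coercion (the helper name is mine, not part of the PR):

```python
from typing import Optional, Union


def coerce_index(raw: Optional[Union[int, str]]) -> Optional[int]:
    """Numeric strings become ints; non-numeric strings are dropped (no index);
    ints and None pass through unchanged."""
    if isinstance(raw, str):
        return int(raw) if raw.isdigit() else None
    return raw


assert coerce_index("3") == 3        # string coming from the UI form
assert coerce_index(3) == 3          # already an int
assert coerce_index("last") is None  # unrecognised string is ignored
assert coerce_index(None) is None    # no index targeted
```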
diff --git a/tests/litellm/integrations/test_custom_prompt_management.py b/tests/litellm/integrations/test_custom_prompt_management.py
index 646b222e72..09ba32b203 100644
--- a/tests/litellm/integrations/test_custom_prompt_management.py
+++ b/tests/litellm/integrations/test_custom_prompt_management.py
@@ -30,7 +30,7 @@ class TestCustomPromptManagement(CustomPromptManagement):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
index fcba3ebbc3..3ffa97be9f 100644
--- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
+++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
@@ -33,6 +33,7 @@ from litellm.integrations.opik.opik import OpikLogger
 from litellm.integrations.opentelemetry import OpenTelemetry
 from litellm.integrations.mlflow import MlflowLogger
 from litellm.integrations.argilla import ArgillaLogger
+from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
 from litellm.integrations.langfuse.langfuse_prompt_management import (
     LangfusePromptManagement,
 )
@@ -73,6 +74,7 @@ callback_class_str_to_classType = {
     "otel": OpenTelemetry,
     "pagerduty": PagerDutyAlerting,
     "gcs_pubsub": GcsPubSubLogger,
+    "anthropic_cache_control_hook": AnthropicCacheControlHook,
 }
 
 expected_env_vars = {
diff --git a/ui/litellm-dashboard/src/components/add_model/advanced_settings.tsx b/ui/litellm-dashboard/src/components/add_model/advanced_settings.tsx
index b887f58fd5..7a1cb93f4c 100644
--- a/ui/litellm-dashboard/src/components/add_model/advanced_settings.tsx
+++ b/ui/litellm-dashboard/src/components/add_model/advanced_settings.tsx
@@ -5,6 +5,7 @@ import { Row, Col, Typography, Card } from "antd";
 import TextArea from "antd/es/input/TextArea";
 import { Team } from "../key_team_helpers/key_list";
 import TeamDropdown from "../common_components/team_dropdown";
+import CacheControlSettings from "./cache_control_settings";
 const { Link } = Typography;
 
 interface AdvancedSettingsProps {
@@ -21,6 +22,7 @@ const AdvancedSettings: React.FC = ({
   const [form] = Form.useForm();
   const [customPricing, setCustomPricing] = React.useState(false);
   const [pricingModel, setPricingModel] = React.useState<'per_token' | 'per_second'>('per_token');
+  const [showCacheControl, setShowCacheControl] = React.useState(false);
 
   // Add validation function for numbers
   const validateNumber = (_: any, value: string) => {
@@ -83,6 +85,24 @@
     }
   };
 
+  const handleCacheControlChange = (checked: boolean) => {
+    setShowCacheControl(checked);
+    if (!checked) {
+      const currentParams = form.getFieldValue('litellm_extra_params');
+      try {
+        let paramsObj = currentParams ? JSON.parse(currentParams) : {};
+        delete paramsObj.cache_control_injection_points;
+        if (Object.keys(paramsObj).length > 0) {
+          form.setFieldValue('litellm_extra_params', JSON.stringify(paramsObj, null, 2));
+        } else {
+          form.setFieldValue('litellm_extra_params', '');
+        }
+      } catch (error) {
+        form.setFieldValue('litellm_extra_params', '');
+      }
+    }
+  };
+
   return (
     <>
@@ -150,6 +170,12 @@ const AdvancedSettings: React.FC = ({
         )}
 
+        <CacheControlSettings
+          form={form}
+          showCacheControl={showCacheControl}
+          onCacheControlChange={handleCacheControlChange}
+        />
+
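For reference, when the toggle is enabled the dashboard ends up storing roughly the following JSON in `litellm_extra_params` (the concrete points are illustrative; the key names mirror the types above and the new component below):

```python
import json

# Assumed example of the litellm_extra_params payload the UI writes:
# one role-based injection point and one index-based injection point.
litellm_extra_params = {
    "cache_control_injection_points": [
        {"location": "message", "role": "system"},
        {"location": "message", "index": 2},
    ]
}
print(json.dumps(litellm_extra_params, indent=2))
```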
diff --git a/ui/litellm-dashboard/src/components/add_model/cache_control_settings.tsx b/ui/litellm-dashboard/src/components/add_model/cache_control_settings.tsx
new file mode 100644
--- /dev/null
+++ b/ui/litellm-dashboard/src/components/add_model/cache_control_settings.tsx
+  void;
+}
+
+const CacheControlSettings: React.FC = ({
+  form,
+  showCacheControl,
+  onCacheControlChange,
+}) => {
+  const updateCacheControlPoints = (injectionPoints: CacheControlInjectionPoint[]) => {
+    const currentParams = form.getFieldValue('litellm_extra_params');
+    try {
+      let paramsObj = currentParams ? JSON.parse(currentParams) : {};
+      if (injectionPoints.length > 0) {
+        paramsObj.cache_control_injection_points = injectionPoints;
+      } else {
+        delete paramsObj.cache_control_injection_points;
+      }
+      if (Object.keys(paramsObj).length > 0) {
+        form.setFieldValue('litellm_extra_params', JSON.stringify(paramsObj, null, 2));
+      } else {
+        form.setFieldValue('litellm_extra_params', '');
+      }
+    } catch (error) {
+      console.error('Error updating cache control points:', error);
+    }
+  };
+
+  return (
+    <>
+
+
+
+      {showCacheControl && (
+
+          Specify either a role (to cache all messages of that role) or a specific message index.
+          If both are provided, the index takes precedence.
+
+
+          {(fields, { add, remove }) => (
+            <>
+              {fields.map((field, index) => (
+
+
+                    {
+                      const values = form.getFieldValue('cache_control_points');
+                      updateCacheControlPoints(values);
+                    }}
+                  />
+
+
+
+                    {
+                      const values = form.getFieldValue('cache_control_points');
+                      updateCacheControlPoints(values);
+                    }}
+                  />
+
+
+                  {fields.length > 1 && (
+                    {
+                      remove(field.name);
+                      setTimeout(() => {
+                        const values = form.getFieldValue('cache_control_points');
+                        updateCacheControlPoints(values);
+                      }, 0);
+                    }}
+                    />
+                  )}
+
+              ))}
+
+
+
+
+            )}
+
+      )}
+  );
+};
+
+export default CacheControlSettings;
\ No newline at end of file
diff --git a/ui/litellm-dashboard/src/components/add_model/handle_add_model_submit.tsx b/ui/litellm-dashboard/src/components/add_model/handle_add_model_submit.tsx
index f71ff1fe69..5924d6beda 100644
--- a/ui/litellm-dashboard/src/components/add_model/handle_add_model_submit.tsx
+++ b/ui/litellm-dashboard/src/components/add_model/handle_add_model_submit.tsx
@@ -60,7 +60,7 @@ export const prepareModelAddRequest = async (
         continue;
       }
       // Skip the custom_pricing and pricing_model fields as they're only used for UI control
-      if (key === 'custom_pricing' || key === 'pricing_model') {
+      if (key === 'custom_pricing' || key === 'pricing_model' || key === 'cache_control') {
         continue;
       }
       if (key == "model_name") {
diff --git a/ui/litellm-dashboard/src/components/model_info_view.tsx b/ui/litellm-dashboard/src/components/model_info_view.tsx
index 22c4fefc99..be4a4aa40f 100644
--- a/ui/litellm-dashboard/src/components/model_info_view.tsx
+++ b/ui/litellm-dashboard/src/components/model_info_view.tsx
@@ -23,6 +23,7 @@ import { getProviderLogoAndName } from "./provider_info_helpers";
 import { getDisplayModelName } from "./view_model/model_name_display";
 import AddCredentialsModal from "./model_add/add_credentials_tab";
 import ReuseCredentialsModal from "./model_add/reuse_credentials";
+import CacheControlSettings from "./add_model/cache_control_settings";
 
 interface ModelInfoViewProps {
@@ -57,6 +58,7 @@ export default function ModelInfoView({
   const [isSaving, setIsSaving] = useState(false);
   const [isEditing, setIsEditing] = useState(false);
   const [existingCredential, setExistingCredential] = useState(null);
+  const [showCacheControl, setShowCacheControl] = useState(false);
 
   const canEditModel = userRole === "Admin" || modelData.model_info.created_by === userID;
   const isAdmin = userRole === "Admin";
@@ -86,6 +88,11 @@ export default function ModelInfoView({
         console.log("modelInfoResponse, ", modelInfoResponse);
         let specificModelData = modelInfoResponse.data[0];
         setLocalModelData(specificModelData);
+
+        // Check if cache control is enabled
+        if (specificModelData?.litellm_params?.cache_control_injection_points) {
+          setShowCacheControl(true);
+        }
       }
       getExistingCredential();
       getModelInfo();
@@ -112,22 +119,31 @@
       if (!accessToken) return;
       setIsSaving(true);
 
+      let updatedLitellmParams = {
+        ...localModelData.litellm_params,
+        model: values.litellm_model_name,
+        api_base: values.api_base,
+        custom_llm_provider: values.custom_llm_provider,
+        organization: values.organization,
+        tpm: values.tpm,
+        rpm: values.rpm,
+        max_retries: values.max_retries,
+        timeout: values.timeout,
+        stream_timeout: values.stream_timeout,
+        input_cost_per_token: values.input_cost / 1_000_000,
+        output_cost_per_token: values.output_cost / 1_000_000,
+      };
+
+      // Handle cache control settings
+      if (values.cache_control && values.cache_control_injection_points?.length > 0) {
+        updatedLitellmParams.cache_control_injection_points = values.cache_control_injection_points;
+      } else {
+        delete updatedLitellmParams.cache_control_injection_points;
+      }
+
       const updateData = {
         model_name: values.model_name,
-        litellm_params: {
-          ...localModelData.litellm_params,
-          model: values.litellm_model_name,
-          api_base: values.api_base,
-          custom_llm_provider: values.custom_llm_provider,
-          organization: values.organization,
-          tpm: values.tpm,
-          rpm: values.rpm,
-          max_retries: values.max_retries,
-          timeout: values.timeout,
-          stream_timeout: values.stream_timeout,
-          input_cost_per_token: values.input_cost / 1_000_000,
-          output_cost_per_token: values.output_cost / 1_000_000,
-        },
+        litellm_params: updatedLitellmParams,
         model_info: {
           id: modelId,
         }
@@ -139,7 +155,7 @@ export default function ModelInfoView({
         ...localModelData,
         model_name: values.model_name,
         litellm_model_name: values.litellm_model_name,
-        litellm_params: updateData.litellm_params
+        litellm_params: updatedLitellmParams
       };
 
       setLocalModelData(updatedModelData);
@@ -337,6 +353,8 @@ export default function ModelInfoView({
                     (localModelData.litellm_params.input_cost_per_token * 1_000_000) : localModelData.model_info?.input_cost_per_token * 1_000_000 || null,
                   output_cost: localModelData.litellm_params?.output_cost_per_token ?
                     (localModelData.litellm_params.output_cost_per_token * 1_000_000) : localModelData.model_info?.output_cost_per_token * 1_000_000 || null,
+                  cache_control: localModelData.litellm_params?.cache_control_injection_points ? true : false,
+                  cache_control_injection_points: localModelData.litellm_params?.cache_control_injection_points || [],
                 }}
                 layout="vertical"
                 onValuesChange={() => setIsDirty(true)}
@@ -499,6 +517,37 @@ export default function ModelInfoView({
                   )}
 
+                  {/* Cache Control Section */}
+                  {isEditing ? (
+                    <CacheControlSettings
+                      form={form}
+                      showCacheControl={showCacheControl}
+                      onCacheControlChange={(checked) => setShowCacheControl(checked)}
+                    />
+                  ) : (
+                    Cache Control
+
+                      {localModelData.litellm_params?.cache_control_injection_points ? (
+
+                          Enabled
+
+                          {localModelData.litellm_params.cache_control_injection_points.map((point: any, i: number) => (
+                              Location: {point.location},
+                              {point.role && Role: {point.role}}
+                              {point.index !== undefined && Index: {point.index}}
+                          ))}
+
+                      ) : (
+                        "Disabled"
+                      )}
+
+                  )}
+
                   Team ID