[UI] Allow setting prompt cache_control_injection_points (#10000)

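This PR lets users declare where LiteLLM should inject Anthropic prompt-caching checkpoints, either from the Admin UI or directly via litellm params. A hedged usage sketch (the cache_control_injection_points parameter name comes from this diff; the model name and exact call-site support are illustrative assumptions):

import litellm

response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # illustrative model name
    messages=[
        {"role": "system", "content": "<very long system prompt worth caching>"},
        {"role": "user", "content": "Summarize our docs."},
    ],
    # Ask the AnthropicCacheControlHook to add {"type": "ephemeral"} cache_control
    # to every system message before the request is sent.
    cache_control_injection_points=[
        {"location": "message", "role": "system"},
    ],
)
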
* test_anthropic_cache_control_hook_system_message

* test_anthropic_cache_control_hook.py

* should_run_prompt_management_hooks

* fix should_run_prompt_management_hooks

* test_anthropic_cache_control_hook_specific_index

* fix test

* fix linting errors

* ChatCompletionCachedContent

* initial commit for cache control

* fixes ui design

* fix inserting cache_control_injection_points

* fix entering cache control points

* fixes for using cache control on ui + backend

* update cache control settings on edit model page

* fix init custom logger compatible class

* fix linting errors

* fix linting errors

* fix get_chat_completion_prompt
Ishaan Jaff 2025-04-14 21:17:42 -07:00 committed by GitHub
parent 6cfa50d278
commit c1a642ce20
16 changed files with 358 additions and 39 deletions

View file

@@ -113,6 +113,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
"pagerduty",
"humanloop",
"gcs_pubsub",
"anthropic_cache_control_hook",
]
logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
_known_custom_logger_compatible_callbacks: List = list(

View file

@@ -7,8 +7,9 @@ Users can define
"""
import copy
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Dict, List, Optional, Tuple, Union, cast
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.integrations.anthropic_cache_control_hook import (
CacheControlInjectionPoint,
@@ -24,7 +25,7 @@ class AnthropicCacheControlHook(CustomPromptManagement):
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:
@@ -64,8 +65,15 @@ class AnthropicCacheControlHook(CustomPromptManagement):
control: ChatCompletionCachedContent = point.get(
"control", None
) or ChatCompletionCachedContent(type="ephemeral")
targetted_index = point.get("index", None)
_targetted_index: Optional[Union[int, str]] = point.get("index", None)
targetted_index: Optional[int] = None
if isinstance(_targetted_index, str):
    if _targetted_index.isdigit():
        targetted_index = int(_targetted_index)
else:
    targetted_index = _targetted_index
targetted_role = point.get("role", None)
# Case 1: Target by specific index
@@ -115,4 +123,28 @@ class AnthropicCacheControlHook(CustomPromptManagement):
@property
def integration_name(self) -> str:
"""Return the integration name for this hook."""
return "anthropic-cache-control-hook"
return "anthropic_cache_control_hook"
@staticmethod
def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
if non_default_params.get("cache_control_injection_points", None):
return True
return False
@staticmethod
def get_custom_logger_for_anthropic_cache_control_hook(
non_default_params: Dict,
) -> Optional[CustomLogger]:
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
non_default_params
):
return _init_custom_logger_compatible_class(
logging_integration="anthropic_cache_control_hook",
internal_usage_cache=None,
llm_router=None,
)
return None

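The UI form serializes the injection index as a string, so the hook now accepts an int or a numeric string. A minimal standalone sketch of that normalization (hypothetical helper name; the real logic lives inline in get_chat_completion_prompt above):

from typing import Optional, Union

def normalize_index(raw: Optional[Union[int, str]]) -> Optional[int]:
    # Numeric strings coming from the UI form are coerced to int;
    # ints (and None) pass through; non-numeric strings are dropped.
    targetted_index: Optional[int] = None
    if isinstance(raw, str):
        if raw.isdigit():
            targetted_index = int(raw)
    else:
        targetted_index = raw
    return targetted_index

assert normalize_index("2") == 2
assert normalize_index(3) == 3
assert normalize_index(None) is None
assert normalize_index("two") is None
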
View file

@@ -94,7 +94,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:

View file

@@ -15,7 +15,7 @@ class CustomPromptManagement(CustomLogger, PromptManagementBase):
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:

View file

@@ -152,14 +152,21 @@ class HumanloopLogger(CustomLogger):
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict,]:
) -> Tuple[
str,
List[AllMessageValues],
dict,
]:
humanloop_api_key = dynamic_callback_params.get(
"humanloop_api_key"
) or get_secret_str("HUMANLOOP_API_KEY")
if prompt_id is None:
raise ValueError("prompt_id is required for Humanloop integration")
if humanloop_api_key is None:
return super().get_chat_completion_prompt(
model=model,

View file

@@ -169,10 +169,14 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict,]:
) -> Tuple[
str,
List[AllMessageValues],
dict,
]:
return self.get_chat_completion_prompt(
model,
messages,

View file

@@ -79,10 +79,12 @@ class PromptManagementBase(ABC):
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict,]:
) -> Tuple[str, List[AllMessageValues], dict]:
if prompt_id is None:
raise ValueError("prompt_id is required for Prompt Management Base class")
if not self.should_run_prompt_management(
prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
):

View file

@@ -36,6 +36,7 @@ from litellm.cost_calculator import (
RealtimeAPITokenUsageProcessor,
_select_model_name_for_cost_calc,
)
from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
@@ -457,15 +458,17 @@ class Logging(LiteLLMLoggingBaseClass):
def should_run_prompt_management_hooks(
self,
prompt_id: str,
non_default_params: Dict,
prompt_id: Optional[str] = None,
) -> bool:
"""
Return True if prompt management hooks should be run
"""
if prompt_id:
return True
if non_default_params.get("cache_control_injection_points", None):
if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
non_default_params
):
return True
return False
@@ -473,15 +476,18 @@
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
non_default_params: Dict,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
prompt_management_logger: Optional[CustomLogger] = None,
) -> Tuple[str, List[AllMessageValues], dict]:
custom_logger = (
prompt_management_logger
or self.get_custom_logger_for_prompt_management(model)
or self.get_custom_logger_for_prompt_management(
model=model, non_default_params=non_default_params
)
)
if custom_logger:
(
model,
@@ -490,7 +496,7 @@
) = custom_logger.get_chat_completion_prompt(
model=model,
messages=messages,
non_default_params=non_default_params,
non_default_params=non_default_params or {},
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=self.standard_callback_dynamic_params,
@@ -499,7 +505,7 @@
return model, messages, non_default_params
def get_custom_logger_for_prompt_management(
self, model: str
self, model: str, non_default_params: Dict
) -> Optional[CustomLogger]:
"""
Get a custom logger for prompt management based on model name or available callbacks.
@@ -534,6 +540,26 @@
self.model_call_details["prompt_integration"] = logger.__class__.__name__
return logger
if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
non_default_params
):
self.model_call_details["prompt_integration"] = (
anthropic_cache_control_logger.__class__.__name__
)
return anthropic_cache_control_logger
return None
def get_custom_logger_for_anthropic_cache_control_hook(
self, non_default_params: Dict
) -> Optional[CustomLogger]:
if non_default_params.get("cache_control_injection_points", None):
custom_logger = _init_custom_logger_compatible_class(
logging_integration="anthropic_cache_control_hook",
internal_usage_cache=None,
llm_router=None,
)
return custom_logger
return None
def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict:
@@ -2922,6 +2948,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
pagerduty_logger = PagerDutyAlerting(**custom_logger_init_args)
_in_memory_loggers.append(pagerduty_logger)
return pagerduty_logger # type: ignore
elif logging_integration == "anthropic_cache_control_hook":
for callback in _in_memory_loggers:
if isinstance(callback, AnthropicCacheControlHook):
return callback
anthropic_cache_control_hook = AnthropicCacheControlHook()
_in_memory_loggers.append(anthropic_cache_control_hook)
return anthropic_cache_control_hook # type: ignore
elif logging_integration == "gcs_pubsub":
for callback in _in_memory_loggers:
if isinstance(callback, GcsPubSubLogger):
@@ -3060,6 +3093,10 @@ def get_custom_logger_compatible_class( # noqa: PLR0915
for callback in _in_memory_loggers:
if isinstance(callback, PagerDutyAlerting):
return callback
elif logging_integration == "anthropic_cache_control_hook":
for callback in _in_memory_loggers:
if isinstance(callback, AnthropicCacheControlHook):
return callback
elif logging_integration == "gcs_pubsub":
for callback in _in_memory_loggers:
if isinstance(callback, GcsPubSubLogger):

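Condensed, the gating added in this file comes down to: run prompt-management hooks when either a prompt_id or cache_control_injection_points is present, and reuse a single in-memory AnthropicCacheControlHook instance per process. A standalone sketch of the gate (a restatement, not the exact method):

from typing import Dict, Optional

def should_run_prompt_management_hooks(
    non_default_params: Dict, prompt_id: Optional[str] = None
) -> bool:
    # A prompt_id triggers the classic prompt-management path; the new
    # cache-control injection points trigger the AnthropicCacheControlHook.
    if prompt_id:
        return True
    return bool(non_default_params.get("cache_control_injection_points"))
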
View file

@@ -12,7 +12,7 @@ class X42PromptManagement(CustomPromptManagement):
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:

View file

@@ -10,7 +10,7 @@ class CacheControlMessageInjectionPoint(TypedDict):
role: Optional[
Literal["user", "system", "assistant"]
] # Optional: target by role (user, system, assistant)
index: Optional[int] # Optional: target by specific index
index: Optional[Union[int, str]] # Optional: target by specific index
control: Optional[ChatCompletionCachedContent]

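For illustration, injection points conforming to the widened TypedDict above might look like this (a hedged sketch; the "location" key and "control" payload follow the shapes used elsewhere in this diff):

points = [
    # target every system message
    {"location": "message", "role": "system", "index": None,
     "control": {"type": "ephemeral"}},
    # target message index 2; the UI may serialize the index as a string
    {"location": "message", "role": None, "index": "2",
     "control": {"type": "ephemeral"}},
]
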
View file

@@ -30,7 +30,7 @@ class TestCustomPromptManagement(CustomPromptManagement):
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:

View file

@@ -33,6 +33,7 @@ from litellm.integrations.opik.opik import OpikLogger
from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.integrations.mlflow import MlflowLogger
from litellm.integrations.argilla import ArgillaLogger
from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
from litellm.integrations.langfuse.langfuse_prompt_management import (
LangfusePromptManagement,
)
@@ -73,6 +74,7 @@ callback_class_str_to_classType = {
"otel": OpenTelemetry,
"pagerduty": PagerDutyAlerting,
"gcs_pubsub": GcsPubSubLogger,
"anthropic_cache_control_hook": AnthropicCacheControlHook,
}
expected_env_vars = {

View file

@@ -5,6 +5,7 @@ import { Row, Col, Typography, Card } from "antd";
import TextArea from "antd/es/input/TextArea";
import { Team } from "../key_team_helpers/key_list";
import TeamDropdown from "../common_components/team_dropdown";
import CacheControlSettings from "./cache_control_settings";
const { Link } = Typography;
interface AdvancedSettingsProps {
@@ -21,6 +22,7 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
const [form] = Form.useForm();
const [customPricing, setCustomPricing] = React.useState(false);
const [pricingModel, setPricingModel] = React.useState<'per_token' | 'per_second'>('per_token');
const [showCacheControl, setShowCacheControl] = React.useState(false);
// Add validation function for numbers
const validateNumber = (_: any, value: string) => {
@@ -83,6 +85,24 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
}
};
const handleCacheControlChange = (checked: boolean) => {
setShowCacheControl(checked);
if (!checked) {
const currentParams = form.getFieldValue('litellm_extra_params');
try {
let paramsObj = currentParams ? JSON.parse(currentParams) : {};
delete paramsObj.cache_control_injection_points;
if (Object.keys(paramsObj).length > 0) {
form.setFieldValue('litellm_extra_params', JSON.stringify(paramsObj, null, 2));
} else {
form.setFieldValue('litellm_extra_params', '');
}
} catch (error) {
form.setFieldValue('litellm_extra_params', '');
}
}
};
return (
<>
<Accordion className="mt-2 mb-4">
@@ -150,6 +170,12 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
</div>
)}
<CacheControlSettings
form={form}
showCacheControl={showCacheControl}
onCacheControlChange={handleCacheControlChange}
/>
<Form.Item
label="Use in pass through routes"
name="use_in_pass_through"

View file

@@ -0,0 +1,159 @@
import React from "react";
import { Form, Switch, Select, Input, Typography } from "antd";
import { PlusOutlined, MinusCircleOutlined } from '@ant-design/icons';
import NumericalInput from "../shared/numerical_input";
const { Text } = Typography;
interface CacheControlInjectionPoint {
location: "message";
role?: "user" | "system" | "assistant";
index?: number;
}
interface CacheControlSettingsProps {
form: any; // Form instance from parent
showCacheControl: boolean;
onCacheControlChange: (checked: boolean) => void;
}
const CacheControlSettings: React.FC<CacheControlSettingsProps> = ({
form,
showCacheControl,
onCacheControlChange,
}) => {
const updateCacheControlPoints = (injectionPoints: CacheControlInjectionPoint[]) => {
const currentParams = form.getFieldValue('litellm_extra_params');
try {
let paramsObj = currentParams ? JSON.parse(currentParams) : {};
if (injectionPoints.length > 0) {
paramsObj.cache_control_injection_points = injectionPoints;
} else {
delete paramsObj.cache_control_injection_points;
}
if (Object.keys(paramsObj).length > 0) {
form.setFieldValue('litellm_extra_params', JSON.stringify(paramsObj, null, 2));
} else {
form.setFieldValue('litellm_extra_params', '');
}
} catch (error) {
console.error('Error updating cache control points:', error);
}
};
return (
<>
<Form.Item
label="Cache Control"
name="cache_control"
valuePropName="checked"
className="mb-4"
tooltip="Tell litellm where to inject cache control checkpoints. You can specify either by role (to apply to all messages of that role) or by specific message index."
>
<Switch onChange={onCacheControlChange} className="bg-gray-600" />
</Form.Item>
{showCacheControl && (
<div className="ml-6 pl-4 border-l-2 border-gray-200">
<Text className="text-sm text-gray-500 block mb-4">
Specify either a role (to cache all messages of that role) or a specific message index.
If both are provided, the index takes precedence.
</Text>
<Form.List
name="cache_control_injection_points"
initialValue={[{ location: "message" }]}
>
{(fields, { add, remove }) => (
<>
{fields.map((field, index) => (
<div key={field.key} className="flex items-center mb-4 gap-4">
<Form.Item
{...field}
label="Type"
name={[field.name, 'location']}
initialValue="message"
className="mb-0"
style={{ width: '180px' }}
>
<Select disabled options={[{ value: 'message', label: 'Message' }]} />
</Form.Item>
<Form.Item
{...field}
label="Role"
name={[field.name, 'role']}
className="mb-0"
style={{ width: '180px' }}
tooltip="Select a role to cache all messages of this type"
>
<Select
placeholder="Select a role"
allowClear
options={[
{ value: 'user', label: 'User' },
{ value: 'system', label: 'System' },
{ value: 'assistant', label: 'Assistant' },
]}
onChange={() => {
const values = form.getFieldValue('cache_control_injection_points');
updateCacheControlPoints(values);
}}
/>
</Form.Item>
<Form.Item
{...field}
label="Index"
name={[field.name, 'index']}
className="mb-0"
style={{ width: '180px' }}
tooltip="Specify a specific message index (optional)"
>
<NumericalInput
type="number"
placeholder="Optional"
step={1}
min={0}
onChange={() => {
const values = form.getFieldValue('cache_control_injection_points');
updateCacheControlPoints(values);
}}
/>
</Form.Item>
{fields.length > 1 && (
<MinusCircleOutlined
className="text-red-500 cursor-pointer text-lg mt-8"
onClick={() => {
remove(field.name);
setTimeout(() => {
const values = form.getFieldValue('cache_control_injection_points');
updateCacheControlPoints(values);
}, 0);
}}
/>
)}
</div>
))}
<Form.Item>
<button
type="button"
className="flex items-center justify-center w-full border border-dashed border-gray-300 py-2 px-4 text-gray-600 hover:text-blue-600 hover:border-blue-300 transition-all rounded"
onClick={() => add()}
>
<PlusOutlined className="mr-2" />
Add Injection Point
</button>
</Form.Item>
</>
)}
</Form.List>
</div>
)}
</>
);
};
export default CacheControlSettings;

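The component keeps the litellm_extra_params textarea in sync with the form: adding points writes them under cache_control_injection_points, and removing the last point (or toggling Cache Control off) deletes the key entirely. A quick sketch of the serialized value it produces (Python stand-in for the JSON.stringify(paramsObj, null, 2) call above):

import json

params_obj = {
    "cache_control_injection_points": [
        {"location": "message", "role": "system"},
    ],
}
print(json.dumps(params_obj, indent=2))
# {
#   "cache_control_injection_points": [
#     {
#       "location": "message",
#       "role": "system"
#     }
#   ]
# }
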
View file

@@ -60,7 +60,7 @@ export const prepareModelAddRequest = async (
continue;
}
// Skip the custom_pricing and pricing_model fields as they're only used for UI control
if (key === 'custom_pricing' || key === 'pricing_model') {
if (key === 'custom_pricing' || key === 'pricing_model' || key === 'cache_control') {
continue;
}
if (key == "model_name") {

View file

@@ -23,6 +23,7 @@ import { getProviderLogoAndName } from "./provider_info_helpers";
import { getDisplayModelName } from "./view_model/model_name_display";
import AddCredentialsModal from "./model_add/add_credentials_tab";
import ReuseCredentialsModal from "./model_add/reuse_credentials";
import CacheControlSettings from "./add_model/cache_control_settings";
interface ModelInfoViewProps {
modelId: string;
@@ -57,6 +58,7 @@ export default function ModelInfoView({
const [isSaving, setIsSaving] = useState(false);
const [isEditing, setIsEditing] = useState(false);
const [existingCredential, setExistingCredential] = useState<CredentialItem | null>(null);
const [showCacheControl, setShowCacheControl] = useState(false);
const canEditModel = userRole === "Admin" || modelData.model_info.created_by === userID;
const isAdmin = userRole === "Admin";
@@ -86,6 +88,11 @@ export default function ModelInfoView({
console.log("modelInfoResponse, ", modelInfoResponse);
let specificModelData = modelInfoResponse.data[0];
setLocalModelData(specificModelData);
// Check if cache control is enabled
if (specificModelData?.litellm_params?.cache_control_injection_points) {
setShowCacheControl(true);
}
}
getExistingCredential();
getModelInfo();
@@ -112,22 +119,31 @@
if (!accessToken) return;
setIsSaving(true);
let updatedLitellmParams = {
...localModelData.litellm_params,
model: values.litellm_model_name,
api_base: values.api_base,
custom_llm_provider: values.custom_llm_provider,
organization: values.organization,
tpm: values.tpm,
rpm: values.rpm,
max_retries: values.max_retries,
timeout: values.timeout,
stream_timeout: values.stream_timeout,
input_cost_per_token: values.input_cost / 1_000_000,
output_cost_per_token: values.output_cost / 1_000_000,
};
// Handle cache control settings
if (values.cache_control && values.cache_control_injection_points?.length > 0) {
updatedLitellmParams.cache_control_injection_points = values.cache_control_injection_points;
} else {
delete updatedLitellmParams.cache_control_injection_points;
}
const updateData = {
model_name: values.model_name,
litellm_params: {
...localModelData.litellm_params,
model: values.litellm_model_name,
api_base: values.api_base,
custom_llm_provider: values.custom_llm_provider,
organization: values.organization,
tpm: values.tpm,
rpm: values.rpm,
max_retries: values.max_retries,
timeout: values.timeout,
stream_timeout: values.stream_timeout,
input_cost_per_token: values.input_cost / 1_000_000,
output_cost_per_token: values.output_cost / 1_000_000,
},
litellm_params: updatedLitellmParams,
model_info: {
id: modelId,
}
@@ -139,7 +155,7 @@
...localModelData,
model_name: values.model_name,
litellm_model_name: values.litellm_model_name,
litellm_params: updateData.litellm_params
litellm_params: updatedLitellmParams
};
setLocalModelData(updatedModelData);
@@ -337,6 +353,8 @@
(localModelData.litellm_params.input_cost_per_token * 1_000_000) : localModelData.model_info?.input_cost_per_token * 1_000_000 || null,
output_cost: localModelData.litellm_params?.output_cost_per_token ?
(localModelData.litellm_params.output_cost_per_token * 1_000_000) : localModelData.model_info?.output_cost_per_token * 1_000_000 || null,
cache_control: localModelData.litellm_params?.cache_control_injection_points ? true : false,
cache_control_injection_points: localModelData.litellm_params?.cache_control_injection_points || [],
}}
layout="vertical"
onValuesChange={() => setIsDirty(true)}
@@ -499,6 +517,37 @@
)}
</div>
{/* Cache Control Section */}
{isEditing ? (
<CacheControlSettings
form={form}
showCacheControl={showCacheControl}
onCacheControlChange={(checked) => setShowCacheControl(checked)}
/>
) : (
<div>
<Text className="font-medium">Cache Control</Text>
<div className="mt-1 p-2 bg-gray-50 rounded">
{localModelData.litellm_params?.cache_control_injection_points ? (
<div>
<p>Enabled</p>
<div className="mt-2">
{localModelData.litellm_params.cache_control_injection_points.map((point: any, i: number) => (
<div key={i} className="text-sm text-gray-600 mb-1">
Location: {point.location},
{point.role && <span> Role: {point.role}</span>}
{point.index !== undefined && <span> Index: {point.index}</span>}
</div>
))}
</div>
</div>
) : (
"Disabled"
)}
</div>
</div>
)}
<div>
<Text className="font-medium">Team ID</Text>
<div className="mt-1 p-2 bg-gray-50 rounded">