Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
[UI] Allow setting prompt cache_control_injection_points (#10000)
* test_anthropic_cache_control_hook_system_message
* test_anthropic_cache_control_hook.py
* should_run_prompt_management_hooks
* fix should_run_prompt_management_hooks
* test_anthropic_cache_control_hook_specific_index
* fix test
* fix linting errors
* ChatCompletionCachedContent
* initial commit for cache control
* fixes ui design
* fix inserting cache_control_injection_points
* fix entering cache control points
* fixes for using cache control on ui + backend
* update cache control settings on edit model page
* fix init custom logger compatible class
* fix linting errors
* fix linting errors
* fix get_chat_completion_prompt
parent 6cfa50d278
commit c1a642ce20
16 changed files with 358 additions and 39 deletions
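The commit wires Anthropic prompt-caching injection points through the proxy UI and backend. For orientation, a minimal sketch of the request-level usage this enables; the kwarg placement is assumed from the hunks below, not a documented API, and the model name is hypothetical:

    import litellm

    # "cache_control_injection_points" travels in non_default_params, where the
    # new AnthropicCacheControlHook (below) picks it up and adds a cache_control
    # checkpoint to every system message.
    response = litellm.completion(
        model="anthropic/claude-3-5-sonnet-20240620",  # hypothetical model name
        messages=[
            {"role": "system", "content": "Long, reusable system prompt..."},
            {"role": "user", "content": "First question."},
        ],
        cache_control_injection_points=[
            {"location": "message", "role": "system"},
        ],
    )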
@@ -113,6 +113,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
     "pagerduty",
     "humanloop",
     "gcs_pubsub",
+    "anthropic_cache_control_hook",
 ]
 logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
 _known_custom_logger_compatible_callbacks: List = list(
@@ -7,8 +7,9 @@ Users can define
 """

 import copy
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Dict, List, Optional, Tuple, Union, cast

+from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.custom_prompt_management import CustomPromptManagement
 from litellm.types.integrations.anthropic_cache_control_hook import (
     CacheControlInjectionPoint,
@@ -24,7 +25,7 @@ class AnthropicCacheControlHook(CustomPromptManagement):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -64,8 +65,15 @@ class AnthropicCacheControlHook(CustomPromptManagement):
                 control: ChatCompletionCachedContent = point.get(
                     "control", None
                 ) or ChatCompletionCachedContent(type="ephemeral")

-                targetted_index = point.get("index", None)
+                _targetted_index: Optional[Union[int, str]] = point.get("index", None)
+                targetted_index: Optional[int] = None
+                if isinstance(_targetted_index, str):
+                    if _targetted_index.isdigit():
+                        targetted_index = int(_targetted_index)
+                else:
+                    targetted_index = _targetted_index
+
                 targetted_role = point.get("role", None)

                 # Case 1: Target by specific index
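The widened index handling above accepts numeric strings as well as ints, since the UI form can serialize the index as a string. The same normalization in isolation — a sketch, not the committed code:

    from typing import Optional, Union

    def normalize_index(raw: Optional[Union[int, str]]) -> Optional[int]:
        # Digit strings become ints; ints (and None) pass through unchanged;
        # a non-digit string is dropped, mirroring the hunk above.
        index: Optional[int] = None
        if isinstance(raw, str):
            if raw.isdigit():
                index = int(raw)
        else:
            index = raw
        return index

    assert normalize_index("2") == 2
    assert normalize_index(2) == 2
    assert normalize_index(None) is None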
@@ -115,4 +123,28 @@ class AnthropicCacheControlHook(CustomPromptManagement):
     @property
     def integration_name(self) -> str:
         """Return the integration name for this hook."""
-        return "anthropic-cache-control-hook"
+        return "anthropic_cache_control_hook"
+
+    @staticmethod
+    def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
+        if non_default_params.get("cache_control_injection_points", None):
+            return True
+        return False
+
+    @staticmethod
+    def get_custom_logger_for_anthropic_cache_control_hook(
+        non_default_params: Dict,
+    ) -> Optional[CustomLogger]:
+        from litellm.litellm_core_utils.litellm_logging import (
+            _init_custom_logger_compatible_class,
+        )
+
+        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
+            non_default_params
+        ):
+            return _init_custom_logger_compatible_class(
+                logging_integration="anthropic_cache_control_hook",
+                internal_usage_cache=None,
+                llm_router=None,
+            )
+        return None
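The two static methods added above gate the hook on a single key in the request's non-default params. Illustratively (imports taken from this diff):

    from litellm.integrations.anthropic_cache_control_hook import (
        AnthropicCacheControlHook,
    )

    # A truthy "cache_control_injection_points" turns the hook on;
    # anything else leaves it off.
    assert AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
        {"cache_control_injection_points": [{"location": "message", "role": "system"}]}
    )
    assert not AnthropicCacheControlHook.should_use_anthropic_cache_control_hook({})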
@@ -94,7 +94,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -15,7 +15,7 @@ class CustomPromptManagement(CustomLogger, PromptManagementBase):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -152,14 +152,21 @@ class HumanloopLogger(CustomLogger):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[str, List[AllMessageValues], dict,]:
+    ) -> Tuple[
+        str,
+        List[AllMessageValues],
+        dict,
+    ]:
         humanloop_api_key = dynamic_callback_params.get(
             "humanloop_api_key"
         ) or get_secret_str("HUMANLOOP_API_KEY")

+        if prompt_id is None:
+            raise ValueError("prompt_id is required for Humanloop integration")
+
         if humanloop_api_key is None:
             return super().get_chat_completion_prompt(
                 model=model,
@@ -169,10 +169,14 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[str, List[AllMessageValues], dict,]:
+    ) -> Tuple[
+        str,
+        List[AllMessageValues],
+        dict,
+    ]:
         return self.get_chat_completion_prompt(
             model,
             messages,
@@ -79,10 +79,12 @@ class PromptManagementBase(ABC):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[str, List[AllMessageValues], dict,]:
+    ) -> Tuple[str, List[AllMessageValues], dict]:
+        if prompt_id is None:
+            raise ValueError("prompt_id is required for Prompt Management Base class")
         if not self.should_run_prompt_management(
             prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
         ):
@@ -36,6 +36,7 @@ from litellm.cost_calculator import (
     RealtimeAPITokenUsageProcessor,
     _select_model_name_for_cost_calc,
 )
+from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
 from litellm.integrations.arize.arize import ArizeLogger
 from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
@@ -457,15 +458,17 @@ class Logging(LiteLLMLoggingBaseClass):

     def should_run_prompt_management_hooks(
         self,
-        prompt_id: str,
         non_default_params: Dict,
+        prompt_id: Optional[str] = None,
     ) -> bool:
         """
         Return True if prompt management hooks should be run
         """
         if prompt_id:
             return True
-        if non_default_params.get("cache_control_injection_points", None):
+        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
+            non_default_params
+        ):
             return True
         return False

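So prompt-management hooks now fire in two cases: an explicit prompt_id, or params that activate the cache-control hook. The decision reduces to the following (illustrative, not the committed code):

    def runs_hooks(prompt_id, non_default_params) -> bool:
        # Mirrors should_run_prompt_management_hooks above.
        return bool(prompt_id) or bool(
            non_default_params.get("cache_control_injection_points")
        )

    assert runs_hooks("my-prompt-id", {})
    assert runs_hooks(None, {"cache_control_injection_points": [{"location": "message"}]})
    assert not runs_hooks(None, {})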
@@ -473,15 +476,18 @@ class Logging(LiteLLMLoggingBaseClass):
         self,
         model: str,
         messages: List[AllMessageValues],
-        non_default_params: dict,
-        prompt_id: str,
+        non_default_params: Dict,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         prompt_management_logger: Optional[CustomLogger] = None,
     ) -> Tuple[str, List[AllMessageValues], dict]:
         custom_logger = (
             prompt_management_logger
-            or self.get_custom_logger_for_prompt_management(model)
+            or self.get_custom_logger_for_prompt_management(
+                model=model, non_default_params=non_default_params
+            )
         )

         if custom_logger:
             (
                 model,
@@ -490,7 +496,7 @@ class Logging(LiteLLMLoggingBaseClass):
             ) = custom_logger.get_chat_completion_prompt(
                 model=model,
                 messages=messages,
-                non_default_params=non_default_params,
+                non_default_params=non_default_params or {},
                 prompt_id=prompt_id,
                 prompt_variables=prompt_variables,
                 dynamic_callback_params=self.standard_callback_dynamic_params,
@@ -499,7 +505,7 @@ class Logging(LiteLLMLoggingBaseClass):
         return model, messages, non_default_params

     def get_custom_logger_for_prompt_management(
-        self, model: str
+        self, model: str, non_default_params: Dict
     ) -> Optional[CustomLogger]:
         """
         Get a custom logger for prompt management based on model name or available callbacks.
@@ -534,6 +540,26 @@ class Logging(LiteLLMLoggingBaseClass):
             self.model_call_details["prompt_integration"] = logger.__class__.__name__
             return logger

+        if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
+            non_default_params
+        ):
+            self.model_call_details["prompt_integration"] = (
+                anthropic_cache_control_logger.__class__.__name__
+            )
+            return anthropic_cache_control_logger
+
+        return None
+
+    def get_custom_logger_for_anthropic_cache_control_hook(
+        self, non_default_params: Dict
+    ) -> Optional[CustomLogger]:
+        if non_default_params.get("cache_control_injection_points", None):
+            custom_logger = _init_custom_logger_compatible_class(
+                logging_integration="anthropic_cache_control_hook",
+                internal_usage_cache=None,
+                llm_router=None,
+            )
+            return custom_logger
         return None

     def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict:
@@ -2922,6 +2948,13 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         pagerduty_logger = PagerDutyAlerting(**custom_logger_init_args)
         _in_memory_loggers.append(pagerduty_logger)
         return pagerduty_logger  # type: ignore
+    elif logging_integration == "anthropic_cache_control_hook":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, AnthropicCacheControlHook):
+                return callback
+        anthropic_cache_control_hook = AnthropicCacheControlHook()
+        _in_memory_loggers.append(anthropic_cache_control_hook)
+        return anthropic_cache_control_hook  # type: ignore
     elif logging_integration == "gcs_pubsub":
         for callback in _in_memory_loggers:
             if isinstance(callback, GcsPubSubLogger):
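The new branch follows the same idempotent-registry pattern as the surrounding integrations: scan _in_memory_loggers for an existing instance before constructing one, so repeated requests share a single hook. Reduced to its core (a sketch using names from this diff):

    from litellm.integrations.anthropic_cache_control_hook import (
        AnthropicCacheControlHook,
    )

    _in_memory_loggers: list = []

    def get_or_create_hook() -> AnthropicCacheControlHook:
        for callback in _in_memory_loggers:
            if isinstance(callback, AnthropicCacheControlHook):
                return callback  # reuse the cached instance
        hook = AnthropicCacheControlHook()
        _in_memory_loggers.append(hook)
        return hook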
@@ -3060,6 +3093,10 @@ def get_custom_logger_compatible_class(  # noqa: PLR0915
         for callback in _in_memory_loggers:
             if isinstance(callback, PagerDutyAlerting):
                 return callback
+    elif logging_integration == "anthropic_cache_control_hook":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, AnthropicCacheControlHook):
+                return callback
     elif logging_integration == "gcs_pubsub":
         for callback in _in_memory_loggers:
             if isinstance(callback, GcsPubSubLogger):
@@ -12,7 +12,7 @@ class X42PromptManagement(CustomPromptManagement):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -10,7 +10,7 @@ class CacheControlMessageInjectionPoint(TypedDict):
     role: Optional[
         Literal["user", "system", "assistant"]
     ]  # Optional: target by role (user, system, assistant)
-    index: Optional[int]  # Optional: target by specific index
+    index: Optional[Union[int, str]]  # Optional: target by specific index
     control: Optional[ChatCompletionCachedContent]

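With index widened to Union[int, str], a typed injection point can carry the numeric string the UI form produces; the hook normalizes it later. Two hypothetical instances follow; the "location" key is assumed from the hook code above, and whether unset keys may be omitted depends on the TypedDict's totality:

    # Target every system message:
    by_role = {"location": "message", "role": "system", "index": None, "control": None}

    # Target message 2, with the index arriving as a string from the UI form:
    by_index = {"location": "message", "role": None, "index": "2", "control": None}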
@@ -30,7 +30,7 @@ class TestCustomPromptManagement(CustomPromptManagement):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -33,6 +33,7 @@ from litellm.integrations.opik.opik import OpikLogger
 from litellm.integrations.opentelemetry import OpenTelemetry
 from litellm.integrations.mlflow import MlflowLogger
 from litellm.integrations.argilla import ArgillaLogger
+from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
 from litellm.integrations.langfuse.langfuse_prompt_management import (
     LangfusePromptManagement,
 )
@@ -73,6 +74,7 @@ callback_class_str_to_classType = {
     "otel": OpenTelemetry,
     "pagerduty": PagerDutyAlerting,
     "gcs_pubsub": GcsPubSubLogger,
+    "anthropic_cache_control_hook": AnthropicCacheControlHook,
 }

 expected_env_vars = {
@@ -5,6 +5,7 @@ import { Row, Col, Typography, Card } from "antd";
 import TextArea from "antd/es/input/TextArea";
 import { Team } from "../key_team_helpers/key_list";
 import TeamDropdown from "../common_components/team_dropdown";
+import CacheControlSettings from "./cache_control_settings";
 const { Link } = Typography;

 interface AdvancedSettingsProps {
@@ -21,6 +22,7 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
   const [form] = Form.useForm();
   const [customPricing, setCustomPricing] = React.useState(false);
   const [pricingModel, setPricingModel] = React.useState<'per_token' | 'per_second'>('per_token');
+  const [showCacheControl, setShowCacheControl] = React.useState(false);

   // Add validation function for numbers
   const validateNumber = (_: any, value: string) => {
@@ -83,6 +85,24 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
     }
   };

+  const handleCacheControlChange = (checked: boolean) => {
+    setShowCacheControl(checked);
+    if (!checked) {
+      const currentParams = form.getFieldValue('litellm_extra_params');
+      try {
+        let paramsObj = currentParams ? JSON.parse(currentParams) : {};
+        delete paramsObj.cache_control_injection_points;
+        if (Object.keys(paramsObj).length > 0) {
+          form.setFieldValue('litellm_extra_params', JSON.stringify(paramsObj, null, 2));
+        } else {
+          form.setFieldValue('litellm_extra_params', '');
+        }
+      } catch (error) {
+        form.setFieldValue('litellm_extra_params', '');
+      }
+    }
+  };
+
   return (
     <>
       <Accordion className="mt-2 mb-4">
@@ -150,6 +170,12 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
           </div>
         )}

+        <CacheControlSettings
+          form={form}
+          showCacheControl={showCacheControl}
+          onCacheControlChange={handleCacheControlChange}
+        />
+
         <Form.Item
           label="Use in pass through routes"
           name="use_in_pass_through"
@@ -0,0 +1,159 @@
+import React from "react";
+import { Form, Switch, Select, Input, Typography } from "antd";
+import { PlusOutlined, MinusCircleOutlined } from '@ant-design/icons';
+import NumericalInput from "../shared/numerical_input";
+
+const { Text } = Typography;
+
+interface CacheControlInjectionPoint {
+  location: "message";
+  role?: "user" | "system" | "assistant";
+  index?: number;
+}
+
+interface CacheControlSettingsProps {
+  form: any; // Form instance from parent
+  showCacheControl: boolean;
+  onCacheControlChange: (checked: boolean) => void;
+}
+
+const CacheControlSettings: React.FC<CacheControlSettingsProps> = ({
+  form,
+  showCacheControl,
+  onCacheControlChange,
+}) => {
+  const updateCacheControlPoints = (injectionPoints: CacheControlInjectionPoint[]) => {
+    const currentParams = form.getFieldValue('litellm_extra_params');
+    try {
+      let paramsObj = currentParams ? JSON.parse(currentParams) : {};
+      if (injectionPoints.length > 0) {
+        paramsObj.cache_control_injection_points = injectionPoints;
+      } else {
+        delete paramsObj.cache_control_injection_points;
+      }
+      if (Object.keys(paramsObj).length > 0) {
+        form.setFieldValue('litellm_extra_params', JSON.stringify(paramsObj, null, 2));
+      } else {
+        form.setFieldValue('litellm_extra_params', '');
+      }
+    } catch (error) {
+      console.error('Error updating cache control points:', error);
+    }
+  };
+
+  return (
+    <>
+      <Form.Item
+        label="Cache Control"
+        name="cache_control"
+        valuePropName="checked"
+        className="mb-4"
+        tooltip="Tell litellm where to inject cache control checkpoints. You can specify either by role (to apply to all messages of that role) or by specific message index."
+      >
+        <Switch onChange={onCacheControlChange} className="bg-gray-600" />
+      </Form.Item>
+
+      {showCacheControl && (
+        <div className="ml-6 pl-4 border-l-2 border-gray-200">
+          <Text className="text-sm text-gray-500 block mb-4">
+            Specify either a role (to cache all messages of that role) or a specific message index.
+            If both are provided, the index takes precedence.
+          </Text>
+
+          <Form.List
+            name="cache_control_injection_points"
+            initialValue={[{ location: "message" }]}
+          >
+            {(fields, { add, remove }) => (
+              <>
+                {fields.map((field, index) => (
+                  <div key={field.key} className="flex items-center mb-4 gap-4">
+                    <Form.Item
+                      {...field}
+                      label="Type"
+                      name={[field.name, 'location']}
+                      initialValue="message"
+                      className="mb-0"
+                      style={{ width: '180px' }}
+                    >
+                      <Select disabled options={[{ value: 'message', label: 'Message' }]} />
+                    </Form.Item>
+
+                    <Form.Item
+                      {...field}
+                      label="Role"
+                      name={[field.name, 'role']}
+                      className="mb-0"
+                      style={{ width: '180px' }}
+                      tooltip="Select a role to cache all messages of this type"
+                    >
+                      <Select
+                        placeholder="Select a role"
+                        allowClear
+                        options={[
+                          { value: 'user', label: 'User' },
+                          { value: 'system', label: 'System' },
+                          { value: 'assistant', label: 'Assistant' },
+                        ]}
+                        onChange={() => {
+                          const values = form.getFieldValue('cache_control_points');
+                          updateCacheControlPoints(values);
+                        }}
+                      />
+                    </Form.Item>
+
+                    <Form.Item
+                      {...field}
+                      label="Index"
+                      name={[field.name, 'index']}
+                      className="mb-0"
+                      style={{ width: '180px' }}
+                      tooltip="Specify a specific message index (optional)"
+                    >
+                      <NumericalInput
+                        type="number"
+                        placeholder="Optional"
+                        step={1}
+                        min={0}
+                        onChange={() => {
+                          const values = form.getFieldValue('cache_control_points');
+                          updateCacheControlPoints(values);
+                        }}
+                      />
+                    </Form.Item>
+
+                    {fields.length > 1 && (
+                      <MinusCircleOutlined
+                        className="text-red-500 cursor-pointer text-lg mt-8"
+                        onClick={() => {
+                          remove(field.name);
+                          setTimeout(() => {
+                            const values = form.getFieldValue('cache_control_points');
+                            updateCacheControlPoints(values);
+                          }, 0);
+                        }}
+                      />
+                    )}
+                  </div>
+                ))}
+
+                <Form.Item>
+                  <button
+                    type="button"
+                    className="flex items-center justify-center w-full border border-dashed border-gray-300 py-2 px-4 text-gray-600 hover:text-blue-600 hover:border-blue-300 transition-all rounded"
+                    onClick={() => add()}
+                  >
+                    <PlusOutlined className="mr-2" />
+                    Add Injection Point
+                  </button>
+                </Form.Item>
+              </>
+            )}
+          </Form.List>
+        </div>
+      )}
+    </>
+  );
+};
+
+export default CacheControlSettings;
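When the switch is on and a single injection point targets the system role, updateCacheControlPoints leaves litellm_extra_params holding JSON equivalent to this Python literal (illustrative):

    litellm_extra_params = {
        "cache_control_injection_points": [
            {"location": "message", "role": "system"},
        ],
    }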
@@ -60,7 +60,7 @@ export const prepareModelAddRequest = async (
       continue;
     }
     // Skip the custom_pricing and pricing_model fields as they're only used for UI control
-    if (key === 'custom_pricing' || key === 'pricing_model') {
+    if (key === 'custom_pricing' || key === 'pricing_model' || key === 'cache_control') {
       continue;
     }
     if (key == "model_name") {
@@ -23,6 +23,7 @@ import { getProviderLogoAndName } from "./provider_info_helpers";
 import { getDisplayModelName } from "./view_model/model_name_display";
 import AddCredentialsModal from "./model_add/add_credentials_tab";
 import ReuseCredentialsModal from "./model_add/reuse_credentials";
+import CacheControlSettings from "./add_model/cache_control_settings";

 interface ModelInfoViewProps {
   modelId: string;
@@ -57,6 +58,7 @@ export default function ModelInfoView({
   const [isSaving, setIsSaving] = useState(false);
   const [isEditing, setIsEditing] = useState(false);
   const [existingCredential, setExistingCredential] = useState<CredentialItem | null>(null);
+  const [showCacheControl, setShowCacheControl] = useState(false);

   const canEditModel = userRole === "Admin" || modelData.model_info.created_by === userID;
   const isAdmin = userRole === "Admin";
@@ -86,6 +88,11 @@ export default function ModelInfoView({
         console.log("modelInfoResponse, ", modelInfoResponse);
         let specificModelData = modelInfoResponse.data[0];
         setLocalModelData(specificModelData);
+
+        // Check if cache control is enabled
+        if (specificModelData?.litellm_params?.cache_control_injection_points) {
+          setShowCacheControl(true);
+        }
       }
       getExistingCredential();
       getModelInfo();
@@ -112,22 +119,31 @@ export default function ModelInfoView({
       if (!accessToken) return;
       setIsSaving(true);

+      let updatedLitellmParams = {
+        ...localModelData.litellm_params,
+        model: values.litellm_model_name,
+        api_base: values.api_base,
+        custom_llm_provider: values.custom_llm_provider,
+        organization: values.organization,
+        tpm: values.tpm,
+        rpm: values.rpm,
+        max_retries: values.max_retries,
+        timeout: values.timeout,
+        stream_timeout: values.stream_timeout,
+        input_cost_per_token: values.input_cost / 1_000_000,
+        output_cost_per_token: values.output_cost / 1_000_000,
+      };
+
+      // Handle cache control settings
+      if (values.cache_control && values.cache_control_injection_points?.length > 0) {
+        updatedLitellmParams.cache_control_injection_points = values.cache_control_injection_points;
+      } else {
+        delete updatedLitellmParams.cache_control_injection_points;
+      }
+
       const updateData = {
         model_name: values.model_name,
-        litellm_params: {
-          ...localModelData.litellm_params,
-          model: values.litellm_model_name,
-          api_base: values.api_base,
-          custom_llm_provider: values.custom_llm_provider,
-          organization: values.organization,
-          tpm: values.tpm,
-          rpm: values.rpm,
-          max_retries: values.max_retries,
-          timeout: values.timeout,
-          stream_timeout: values.stream_timeout,
-          input_cost_per_token: values.input_cost / 1_000_000,
-          output_cost_per_token: values.output_cost / 1_000_000,
-        },
+        litellm_params: updatedLitellmParams,
         model_info: {
           id: modelId,
         }
|
@ -139,7 +155,7 @@ export default function ModelInfoView({
|
||||||
...localModelData,
|
...localModelData,
|
||||||
model_name: values.model_name,
|
model_name: values.model_name,
|
||||||
litellm_model_name: values.litellm_model_name,
|
litellm_model_name: values.litellm_model_name,
|
||||||
litellm_params: updateData.litellm_params
|
litellm_params: updatedLitellmParams
|
||||||
};
|
};
|
||||||
|
|
||||||
setLocalModelData(updatedModelData);
|
setLocalModelData(updatedModelData);
|
||||||
|
@@ -337,6 +353,8 @@ export default function ModelInfoView({
                   (localModelData.litellm_params.input_cost_per_token * 1_000_000) : localModelData.model_info?.input_cost_per_token * 1_000_000 || null,
                 output_cost: localModelData.litellm_params?.output_cost_per_token ?
                   (localModelData.litellm_params.output_cost_per_token * 1_000_000) : localModelData.model_info?.output_cost_per_token * 1_000_000 || null,
+                cache_control: localModelData.litellm_params?.cache_control_injection_points ? true : false,
+                cache_control_injection_points: localModelData.litellm_params?.cache_control_injection_points || [],
               }}
               layout="vertical"
               onValuesChange={() => setIsDirty(true)}
@@ -499,6 +517,37 @@ export default function ModelInfoView({
                 )}
               </div>

+              {/* Cache Control Section */}
+              {isEditing ? (
+                <CacheControlSettings
+                  form={form}
+                  showCacheControl={showCacheControl}
+                  onCacheControlChange={(checked) => setShowCacheControl(checked)}
+                />
+              ) : (
+                <div>
+                  <Text className="font-medium">Cache Control</Text>
+                  <div className="mt-1 p-2 bg-gray-50 rounded">
+                    {localModelData.litellm_params?.cache_control_injection_points ? (
+                      <div>
+                        <p>Enabled</p>
+                        <div className="mt-2">
+                          {localModelData.litellm_params.cache_control_injection_points.map((point: any, i: number) => (
+                            <div key={i} className="text-sm text-gray-600 mb-1">
+                              Location: {point.location},
+                              {point.role && <span> Role: {point.role}</span>}
+                              {point.index !== undefined && <span> Index: {point.index}</span>}
+                            </div>
+                          ))}
+                        </div>
+                      </div>
+                    ) : (
+                      "Disabled"
+                    )}
+                  </div>
+                </div>
+              )}
+
               <div>
                 <Text className="font-medium">Team ID</Text>
                 <div className="mt-1 p-2 bg-gray-50 rounded">