Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
[UI] Allow setting prompt cache_control_injection_points (#10000)
* test_anthropic_cache_control_hook_system_message
* test_anthropic_cache_control_hook.py
* should_run_prompt_management_hooks
* fix should_run_prompt_management_hooks
* test_anthropic_cache_control_hook_specific_index
* fix test
* fix linting errors
* ChatCompletionCachedContent
* initial commit for cache control
* fixes ui design
* fix inserting cache_control_injection_points
* fix entering cache control points
* fixes for using cache control on ui + backend
* update cache control settings on edit model page
* fix init custom logger compatible class
* fix linting errors
* fix linting errors
* fix get_chat_completion_prompt
This commit is contained in:
parent: 6cfa50d278
commit: c1a642ce20

16 changed files with 358 additions and 39 deletions
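Before the diff, a quick orientation: the feature lets callers ask litellm to inject Anthropic prompt-caching checkpoints into messages. A minimal usage sketch (the model name and message contents are illustrative; the cache_control_injection_points parameter and the injection-point shape come from this diff):

    import litellm

    response = litellm.completion(
        model="anthropic/claude-3-5-sonnet-20240620",
        messages=[
            {"role": "system", "content": "Long, reusable system prompt ..."},
            {"role": "user", "content": "Question about the document."},
        ],
        # Ask litellm to add an Anthropic cache_control checkpoint to every
        # system message before the request is sent:
        cache_control_injection_points=[
            {"location": "message", "role": "system"},
        ],
    )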
@@ -113,6 +113,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
     "pagerduty",
     "humanloop",
     "gcs_pubsub",
+    "anthropic_cache_control_hook",
 ]
 logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
 _known_custom_logger_compatible_callbacks: List = list(
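Once the name is registered in the literal above, it is accepted anywhere litellm takes string callback names. A sketch, assuming the usual string-callback API (the hook is normally wired up automatically, so setting it by hand is optional):

    import litellm

    # "anthropic_cache_control_hook" is now a recognized
    # custom-logger-compatible callback name:
    litellm.callbacks = ["anthropic_cache_control_hook"]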
@@ -7,8 +7,9 @@ Users can define
 """

 import copy
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Dict, List, Optional, Tuple, Union, cast

+from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.custom_prompt_management import CustomPromptManagement
 from litellm.types.integrations.anthropic_cache_control_hook import (
     CacheControlInjectionPoint,
@@ -24,7 +25,7 @@ class AnthropicCacheControlHook(CustomPromptManagement):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -64,8 +65,15 @@ class AnthropicCacheControlHook(CustomPromptManagement):
         control: ChatCompletionCachedContent = point.get(
             "control", None
         ) or ChatCompletionCachedContent(type="ephemeral")
-        targetted_index = point.get("index", None)
+        _targetted_index: Optional[Union[int, str]] = point.get("index", None)
+        targetted_index: Optional[int] = None
+        if isinstance(_targetted_index, str):
+            if _targetted_index.isdigit():
+                targetted_index = int(_targetted_index)
+        else:
+            targetted_index = _targetted_index

         targetted_role = point.get("role", None)

         # Case 1: Target by specific index
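A standalone sketch of the coercion the hunk above introduces: indices arriving from the UI as strings ("2") are converted to ints, while non-numeric strings are dropped and targetted_index stays None:

    from typing import Optional, Union

    def coerce_index(raw: Optional[Union[int, str]]) -> Optional[int]:
        # Mirrors the logic added in the diff above.
        targetted_index: Optional[int] = None
        if isinstance(raw, str):
            if raw.isdigit():
                targetted_index = int(raw)
        else:
            targetted_index = raw
        return targetted_index

    assert coerce_index("2") == 2
    assert coerce_index(3) == 3
    assert coerce_index("not-a-number") is None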
@@ -115,4 +123,28 @@ class AnthropicCacheControlHook(CustomPromptManagement):
     @property
     def integration_name(self) -> str:
         """Return the integration name for this hook."""
-        return "anthropic-cache-control-hook"
+        return "anthropic_cache_control_hook"
+
+    @staticmethod
+    def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
+        if non_default_params.get("cache_control_injection_points", None):
+            return True
+        return False
+
+    @staticmethod
+    def get_custom_logger_for_anthropic_cache_control_hook(
+        non_default_params: Dict,
+    ) -> Optional[CustomLogger]:
+        from litellm.litellm_core_utils.litellm_logging import (
+            _init_custom_logger_compatible_class,
+        )
+
+        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
+            non_default_params
+        ):
+            return _init_custom_logger_compatible_class(
+                logging_integration="anthropic_cache_control_hook",
+                internal_usage_cache=None,
+                llm_router=None,
+            )
+        return None
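The new static helpers gate purely on the presence of cache_control_injection_points in non_default_params, e.g.:

    from litellm.integrations.anthropic_cache_control_hook import (
        AnthropicCacheControlHook,
    )

    points = [{"location": "message", "role": "system"}]
    assert AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
        {"cache_control_injection_points": points}
    )
    assert not AnthropicCacheControlHook.should_use_anthropic_cache_control_hook({})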
@@ -94,7 +94,7 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -15,7 +15,7 @@ class CustomPromptManagement(CustomLogger, PromptManagementBase):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
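With prompt_id now Optional[str] across these interfaces, custom prompt-management subclasses must tolerate calls that carry no prompt_id. A minimal conforming override (a sketch; the passthrough body and import paths are assumptions based on the repo layout):

    from typing import List, Optional, Tuple

    from litellm.integrations.custom_prompt_management import CustomPromptManagement
    from litellm.types.llms.openai import AllMessageValues
    from litellm.types.utils import StandardCallbackDynamicParams

    class MyPromptManager(CustomPromptManagement):
        def get_chat_completion_prompt(
            self,
            model: str,
            messages: List[AllMessageValues],
            non_default_params: dict,
            prompt_id: Optional[str],
            prompt_variables: Optional[dict],
            dynamic_callback_params: StandardCallbackDynamicParams,
        ) -> Tuple[str, List[AllMessageValues], dict]:
            # Passthrough when there is nothing to manage.
            return model, messages, non_default_params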
@@ -152,14 +152,21 @@ class HumanloopLogger(CustomLogger):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[str, List[AllMessageValues], dict,]:
+    ) -> Tuple[
+        str,
+        List[AllMessageValues],
+        dict,
+    ]:
         humanloop_api_key = dynamic_callback_params.get(
             "humanloop_api_key"
         ) or get_secret_str("HUMANLOOP_API_KEY")

+        if prompt_id is None:
+            raise ValueError("prompt_id is required for Humanloop integration")
+
         if humanloop_api_key is None:
             return super().get_chat_completion_prompt(
                 model=model,
@@ -169,10 +169,14 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[str, List[AllMessageValues], dict,]:
+    ) -> Tuple[
+        str,
+        List[AllMessageValues],
+        dict,
+    ]:
         return self.get_chat_completion_prompt(
             model,
             messages,
@@ -79,10 +79,12 @@ class PromptManagementBase(ABC):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[str, List[AllMessageValues], dict,]:
+    ) -> Tuple[str, List[AllMessageValues], dict]:
+        if prompt_id is None:
+            raise ValueError("prompt_id is required for Prompt Management Base class")
         if not self.should_run_prompt_management(
             prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
         ):
@@ -36,6 +36,7 @@ from litellm.cost_calculator import (
     RealtimeAPITokenUsageProcessor,
     _select_model_name_for_cost_calc,
 )
+from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
 from litellm.integrations.arize.arize import ArizeLogger
 from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
@@ -457,15 +458,17 @@ class Logging(LiteLLMLoggingBaseClass):

     def should_run_prompt_management_hooks(
         self,
-        prompt_id: str,
         non_default_params: Dict,
+        prompt_id: Optional[str] = None,
     ) -> bool:
         """
         Return True if prompt management hooks should be run
         """
         if prompt_id:
             return True
-        if non_default_params.get("cache_control_injection_points", None):
+        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
+            non_default_params
+        ):
             return True
         return False
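The effect of the widened gate, reduced to a sketch: hooks now run either when a prompt_id is supplied or when cache-control injection points are present:

    def should_run_hooks(non_default_params: dict, prompt_id=None) -> bool:
        # Simplified restatement of should_run_prompt_management_hooks above.
        if prompt_id:
            return True
        return bool(non_default_params.get("cache_control_injection_points"))

    assert should_run_hooks({}, prompt_id="my-prompt")
    assert should_run_hooks({"cache_control_injection_points": [{"location": "message"}]})
    assert not should_run_hooks({})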
@@ -473,15 +476,18 @@ class Logging(LiteLLMLoggingBaseClass):
         self,
         model: str,
         messages: List[AllMessageValues],
-        non_default_params: dict,
-        prompt_id: str,
+        non_default_params: Dict,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         prompt_management_logger: Optional[CustomLogger] = None,
     ) -> Tuple[str, List[AllMessageValues], dict]:
         custom_logger = (
             prompt_management_logger
-            or self.get_custom_logger_for_prompt_management(model)
+            or self.get_custom_logger_for_prompt_management(
+                model=model, non_default_params=non_default_params
+            )
         )

         if custom_logger:
             (
                 model,
@@ -490,7 +496,7 @@ class Logging(LiteLLMLoggingBaseClass):
             ) = custom_logger.get_chat_completion_prompt(
                 model=model,
                 messages=messages,
-                non_default_params=non_default_params,
+                non_default_params=non_default_params or {},
                 prompt_id=prompt_id,
                 prompt_variables=prompt_variables,
                 dynamic_callback_params=self.standard_callback_dynamic_params,
@@ -499,7 +505,7 @@ class Logging(LiteLLMLoggingBaseClass):
         return model, messages, non_default_params

     def get_custom_logger_for_prompt_management(
-        self, model: str
+        self, model: str, non_default_params: Dict
     ) -> Optional[CustomLogger]:
         """
         Get a custom logger for prompt management based on model name or available callbacks.
@@ -534,6 +540,26 @@ class Logging(LiteLLMLoggingBaseClass):
             self.model_call_details["prompt_integration"] = logger.__class__.__name__
             return logger

+        if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
+            non_default_params
+        ):
+            self.model_call_details["prompt_integration"] = (
+                anthropic_cache_control_logger.__class__.__name__
+            )
+            return anthropic_cache_control_logger
+
         return None

+    def get_custom_logger_for_anthropic_cache_control_hook(
+        self, non_default_params: Dict
+    ) -> Optional[CustomLogger]:
+        if non_default_params.get("cache_control_injection_points", None):
+            custom_logger = _init_custom_logger_compatible_class(
+                logging_integration="anthropic_cache_control_hook",
+                internal_usage_cache=None,
+                llm_router=None,
+            )
+            return custom_logger
+        return None
+
     def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict:
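Resolution order after this hunk, restated as a sketch: a model-matched prompt-management logger still wins, and the Anthropic cache-control hook is consulted only as a fallback:

    def resolve_prompt_logger(model_matched_logger, non_default_params: dict):
        # Simplified restatement of get_custom_logger_for_prompt_management.
        if model_matched_logger is not None:
            return model_matched_logger
        if non_default_params.get("cache_control_injection_points"):
            return "anthropic_cache_control_hook"  # stands in for the hook instance
        return None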
@@ -2922,6 +2948,13 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         pagerduty_logger = PagerDutyAlerting(**custom_logger_init_args)
         _in_memory_loggers.append(pagerduty_logger)
         return pagerduty_logger  # type: ignore
+    elif logging_integration == "anthropic_cache_control_hook":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, AnthropicCacheControlHook):
+                return callback
+        anthropic_cache_control_hook = AnthropicCacheControlHook()
+        _in_memory_loggers.append(anthropic_cache_control_hook)
+        return anthropic_cache_control_hook  # type: ignore
     elif logging_integration == "gcs_pubsub":
         for callback in _in_memory_loggers:
             if isinstance(callback, GcsPubSubLogger):
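The new branch follows the file's existing get-or-create pattern over the in-memory logger registry, which keeps the hook a process-wide singleton. The pattern in isolation:

    _in_memory_loggers: list = []

    def get_or_create(logger_cls):
        # Reuse an existing instance if one was already registered ...
        for callback in _in_memory_loggers:
            if isinstance(callback, logger_cls):
                return callback
        # ... otherwise create and register a new one.
        instance = logger_cls()
        _in_memory_loggers.append(instance)
        return instance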
@@ -3060,6 +3093,10 @@ def get_custom_logger_compatible_class(  # noqa: PLR0915
         for callback in _in_memory_loggers:
             if isinstance(callback, PagerDutyAlerting):
                 return callback
+    elif logging_integration == "anthropic_cache_control_hook":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, AnthropicCacheControlHook):
+                return callback
     elif logging_integration == "gcs_pubsub":
         for callback in _in_memory_loggers:
             if isinstance(callback, GcsPubSubLogger):
@@ -12,7 +12,7 @@ class X42PromptManagement(CustomPromptManagement):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -10,7 +10,7 @@ class CacheControlMessageInjectionPoint(TypedDict):
     role: Optional[
         Literal["user", "system", "assistant"]
     ]  # Optional: target by role (user, system, assistant)
-    index: Optional[int]  # Optional: target by specific index
+    index: Optional[Union[int, str]]  # Optional: target by specific index
     control: Optional[ChatCompletionCachedContent]
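Concrete values that satisfy the widened TypedDict (illustrative):

    by_role = {"location": "message", "role": "system"}  # cache all system messages
    by_index = {"location": "message", "index": 2}       # cache the message at index 2
    from_ui = {"location": "message", "index": "2"}      # string index from the UI; coerced to int by the hook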
@@ -30,7 +30,7 @@ class TestCustomPromptManagement(CustomPromptManagement):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        prompt_id: str,
+        prompt_id: Optional[str],
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -33,6 +33,7 @@ from litellm.integrations.opik.opik import OpikLogger
 from litellm.integrations.opentelemetry import OpenTelemetry
 from litellm.integrations.mlflow import MlflowLogger
 from litellm.integrations.argilla import ArgillaLogger
+from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
 from litellm.integrations.langfuse.langfuse_prompt_management import (
     LangfusePromptManagement,
 )
@@ -73,6 +74,7 @@ callback_class_str_to_classType = {
     "otel": OpenTelemetry,
     "pagerduty": PagerDutyAlerting,
     "gcs_pubsub": GcsPubSubLogger,
+    "anthropic_cache_control_hook": AnthropicCacheControlHook,
 }

 expected_env_vars = {
@@ -5,6 +5,7 @@ import { Row, Col, Typography, Card } from "antd";
 import TextArea from "antd/es/input/TextArea";
 import { Team } from "../key_team_helpers/key_list";
 import TeamDropdown from "../common_components/team_dropdown";
+import CacheControlSettings from "./cache_control_settings";
 const { Link } = Typography;

 interface AdvancedSettingsProps {
@@ -21,6 +22,7 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
   const [form] = Form.useForm();
   const [customPricing, setCustomPricing] = React.useState(false);
   const [pricingModel, setPricingModel] = React.useState<'per_token' | 'per_second'>('per_token');
+  const [showCacheControl, setShowCacheControl] = React.useState(false);

   // Add validation function for numbers
   const validateNumber = (_: any, value: string) => {
@@ -83,6 +85,24 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
     }
   };

+  const handleCacheControlChange = (checked: boolean) => {
+    setShowCacheControl(checked);
+    if (!checked) {
+      const currentParams = form.getFieldValue('litellm_extra_params');
+      try {
+        let paramsObj = currentParams ? JSON.parse(currentParams) : {};
+        delete paramsObj.cache_control_injection_points;
+        if (Object.keys(paramsObj).length > 0) {
+          form.setFieldValue('litellm_extra_params', JSON.stringify(paramsObj, null, 2));
+        } else {
+          form.setFieldValue('litellm_extra_params', '');
+        }
+      } catch (error) {
+        form.setFieldValue('litellm_extra_params', '');
+      }
+    }
+  };
+
   return (
     <>
       <Accordion className="mt-2 mb-4">
@@ -150,6 +170,12 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
             </div>
           )}

+          <CacheControlSettings
+            form={form}
+            showCacheControl={showCacheControl}
+            onCacheControlChange={handleCacheControlChange}
+          />
+
           <Form.Item
             label="Use in pass through routes"
             name="use_in_pass_through"
@@ -0,0 +1,159 @@
+import React from "react";
+import { Form, Switch, Select, Input, Typography } from "antd";
+import { PlusOutlined, MinusCircleOutlined } from '@ant-design/icons';
+import NumericalInput from "../shared/numerical_input";
+
+const { Text } = Typography;
+
+interface CacheControlInjectionPoint {
+  location: "message";
+  role?: "user" | "system" | "assistant";
+  index?: number;
+}
+
+interface CacheControlSettingsProps {
+  form: any; // Form instance from parent
+  showCacheControl: boolean;
+  onCacheControlChange: (checked: boolean) => void;
+}
+
+const CacheControlSettings: React.FC<CacheControlSettingsProps> = ({
+  form,
+  showCacheControl,
+  onCacheControlChange,
+}) => {
+  const updateCacheControlPoints = (injectionPoints: CacheControlInjectionPoint[]) => {
+    const currentParams = form.getFieldValue('litellm_extra_params');
+    try {
+      let paramsObj = currentParams ? JSON.parse(currentParams) : {};
+      if (injectionPoints.length > 0) {
+        paramsObj.cache_control_injection_points = injectionPoints;
+      } else {
+        delete paramsObj.cache_control_injection_points;
+      }
+      if (Object.keys(paramsObj).length > 0) {
+        form.setFieldValue('litellm_extra_params', JSON.stringify(paramsObj, null, 2));
+      } else {
+        form.setFieldValue('litellm_extra_params', '');
+      }
+    } catch (error) {
+      console.error('Error updating cache control points:', error);
+    }
+  };
+
+  return (
+    <>
+      <Form.Item
+        label="Cache Control"
+        name="cache_control"
+        valuePropName="checked"
+        className="mb-4"
+        tooltip="Tell litellm where to inject cache control checkpoints. You can specify either by role (to apply to all messages of that role) or by specific message index."
+      >
+        <Switch onChange={onCacheControlChange} className="bg-gray-600" />
+      </Form.Item>
+
+      {showCacheControl && (
+        <div className="ml-6 pl-4 border-l-2 border-gray-200">
+          <Text className="text-sm text-gray-500 block mb-4">
+            Specify either a role (to cache all messages of that role) or a specific message index.
+            If both are provided, the index takes precedence.
+          </Text>
+
+          <Form.List
+            name="cache_control_injection_points"
+            initialValue={[{ location: "message" }]}
+          >
+            {(fields, { add, remove }) => (
+              <>
+                {fields.map((field, index) => (
+                  <div key={field.key} className="flex items-center mb-4 gap-4">
+                    <Form.Item
+                      {...field}
+                      label="Type"
+                      name={[field.name, 'location']}
+                      initialValue="message"
+                      className="mb-0"
+                      style={{ width: '180px' }}
+                    >
+                      <Select disabled options={[{ value: 'message', label: 'Message' }]} />
+                    </Form.Item>
+
+                    <Form.Item
+                      {...field}
+                      label="Role"
+                      name={[field.name, 'role']}
+                      className="mb-0"
+                      style={{ width: '180px' }}
+                      tooltip="Select a role to cache all messages of this type"
+                    >
+                      <Select
+                        placeholder="Select a role"
+                        allowClear
+                        options={[
+                          { value: 'user', label: 'User' },
+                          { value: 'system', label: 'System' },
+                          { value: 'assistant', label: 'Assistant' },
+                        ]}
+                        onChange={() => {
+                          const values = form.getFieldValue('cache_control_points');
+                          updateCacheControlPoints(values);
+                        }}
+                      />
+                    </Form.Item>
+
+                    <Form.Item
+                      {...field}
+                      label="Index"
+                      name={[field.name, 'index']}
+                      className="mb-0"
+                      style={{ width: '180px' }}
+                      tooltip="Specify a specific message index (optional)"
+                    >
+                      <NumericalInput
+                        type="number"
+                        placeholder="Optional"
+                        step={1}
+                        min={0}
+                        onChange={() => {
+                          const values = form.getFieldValue('cache_control_points');
+                          updateCacheControlPoints(values);
+                        }}
+                      />
+                    </Form.Item>
+
+                    {fields.length > 1 && (
+                      <MinusCircleOutlined
+                        className="text-red-500 cursor-pointer text-lg mt-8"
+                        onClick={() => {
+                          remove(field.name);
+                          setTimeout(() => {
+                            const values = form.getFieldValue('cache_control_points');
+                            updateCacheControlPoints(values);
+                          }, 0);
+                        }}
+                      />
+                    )}
+                  </div>
+                ))}
+
+                <Form.Item>
+                  <button
+                    type="button"
+                    className="flex items-center justify-center w-full border border-dashed border-gray-300 py-2 px-4 text-gray-600 hover:text-blue-600 hover:border-blue-300 transition-all rounded"
+                    onClick={() => add()}
+                  >
+                    <PlusOutlined className="mr-2" />
+                    Add Injection Point
+                  </button>
+                </Form.Item>
+              </>
+            )}
+          </Form.List>
+        </div>
+      )}
+    </>
+  );
+};
+
+export default CacheControlSettings;
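For reference, the payload this component ultimately serializes into the model's litellm_extra_params field, shown as the equivalent Python dict (values illustrative; the key names come from the component above):

    litellm_extra_params = {
        "cache_control_injection_points": [
            {"location": "message", "role": "system"},
            {"location": "message", "index": 2},
        ]
    }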
@@ -60,7 +60,7 @@ export const prepareModelAddRequest = async (
         continue;
       }
       // Skip the custom_pricing and pricing_model fields as they're only used for UI control
-      if (key === 'custom_pricing' || key === 'pricing_model') {
+      if (key === 'custom_pricing' || key === 'pricing_model' || key === 'cache_control') {
         continue;
       }
       if (key == "model_name") {
@@ -23,6 +23,7 @@ import { getProviderLogoAndName } from "./provider_info_helpers";
 import { getDisplayModelName } from "./view_model/model_name_display";
 import AddCredentialsModal from "./model_add/add_credentials_tab";
 import ReuseCredentialsModal from "./model_add/reuse_credentials";
+import CacheControlSettings from "./add_model/cache_control_settings";

 interface ModelInfoViewProps {
   modelId: string;
@@ -57,6 +58,7 @@ export default function ModelInfoView({
   const [isSaving, setIsSaving] = useState(false);
   const [isEditing, setIsEditing] = useState(false);
   const [existingCredential, setExistingCredential] = useState<CredentialItem | null>(null);
+  const [showCacheControl, setShowCacheControl] = useState(false);

   const canEditModel = userRole === "Admin" || modelData.model_info.created_by === userID;
   const isAdmin = userRole === "Admin";
@@ -86,6 +88,11 @@ export default function ModelInfoView({
         console.log("modelInfoResponse, ", modelInfoResponse);
         let specificModelData = modelInfoResponse.data[0];
         setLocalModelData(specificModelData);
+
+        // Check if cache control is enabled
+        if (specificModelData?.litellm_params?.cache_control_injection_points) {
+          setShowCacheControl(true);
+        }
       }
       getExistingCredential();
       getModelInfo();
@@ -112,22 +119,31 @@ export default function ModelInfoView({
     if (!accessToken) return;
     setIsSaving(true);

+    let updatedLitellmParams = {
+      ...localModelData.litellm_params,
+      model: values.litellm_model_name,
+      api_base: values.api_base,
+      custom_llm_provider: values.custom_llm_provider,
+      organization: values.organization,
+      tpm: values.tpm,
+      rpm: values.rpm,
+      max_retries: values.max_retries,
+      timeout: values.timeout,
+      stream_timeout: values.stream_timeout,
+      input_cost_per_token: values.input_cost / 1_000_000,
+      output_cost_per_token: values.output_cost / 1_000_000,
+    };
+
+    // Handle cache control settings
+    if (values.cache_control && values.cache_control_injection_points?.length > 0) {
+      updatedLitellmParams.cache_control_injection_points = values.cache_control_injection_points;
+    } else {
+      delete updatedLitellmParams.cache_control_injection_points;
+    }
+
     const updateData = {
       model_name: values.model_name,
-      litellm_params: {
-        ...localModelData.litellm_params,
-        model: values.litellm_model_name,
-        api_base: values.api_base,
-        custom_llm_provider: values.custom_llm_provider,
-        organization: values.organization,
-        tpm: values.tpm,
-        rpm: values.rpm,
-        max_retries: values.max_retries,
-        timeout: values.timeout,
-        stream_timeout: values.stream_timeout,
-        input_cost_per_token: values.input_cost / 1_000_000,
-        output_cost_per_token: values.output_cost / 1_000_000,
-      },
+      litellm_params: updatedLitellmParams,
       model_info: {
         id: modelId,
       }
@@ -139,7 +155,7 @@ export default function ModelInfoView({
       ...localModelData,
       model_name: values.model_name,
       litellm_model_name: values.litellm_model_name,
-      litellm_params: updateData.litellm_params
+      litellm_params: updatedLitellmParams
     };

     setLocalModelData(updatedModelData);
@@ -337,6 +353,8 @@ export default function ModelInfoView({
               (localModelData.litellm_params.input_cost_per_token * 1_000_000) : localModelData.model_info?.input_cost_per_token * 1_000_000 || null,
             output_cost: localModelData.litellm_params?.output_cost_per_token ?
               (localModelData.litellm_params.output_cost_per_token * 1_000_000) : localModelData.model_info?.output_cost_per_token * 1_000_000 || null,
+            cache_control: localModelData.litellm_params?.cache_control_injection_points ? true : false,
+            cache_control_injection_points: localModelData.litellm_params?.cache_control_injection_points || [],
           }}
           layout="vertical"
           onValuesChange={() => setIsDirty(true)}
@@ -499,6 +517,37 @@ export default function ModelInfoView({
                     )}
                   </div>

+                  {/* Cache Control Section */}
+                  {isEditing ? (
+                    <CacheControlSettings
+                      form={form}
+                      showCacheControl={showCacheControl}
+                      onCacheControlChange={(checked) => setShowCacheControl(checked)}
+                    />
+                  ) : (
+                    <div>
+                      <Text className="font-medium">Cache Control</Text>
+                      <div className="mt-1 p-2 bg-gray-50 rounded">
+                        {localModelData.litellm_params?.cache_control_injection_points ? (
+                          <div>
+                            <p>Enabled</p>
+                            <div className="mt-2">
+                              {localModelData.litellm_params.cache_control_injection_points.map((point: any, i: number) => (
+                                <div key={i} className="text-sm text-gray-600 mb-1">
+                                  Location: {point.location},
+                                  {point.role && <span> Role: {point.role}</span>}
+                                  {point.index !== undefined && <span> Index: {point.index}</span>}
+                                </div>
+                              ))}
+                            </div>
+                          </div>
+                        ) : (
+                          "Disabled"
+                        )}
+                      </div>
+                    </div>
+                  )}
+
                   <div>
                     <Text className="font-medium">Team ID</Text>
                     <div className="mt-1 p-2 bg-gray-50 rounded">