litellm-mirror/litellm/llms/vertex_ai/context_caching/transformation.py
Krish Dholakia c3edfc2c92
LiteLLM Minor Fixes & Improvements (12/23/2024) - p3 (#7394)
* build(model_prices_and_context_window.json): add gemini-1.5-flash context caching

* fix(context_caching/transformation.py): just use last identified cache point

Fixes https://github.com/BerriAI/litellm/issues/6738

* fix(context_caching/transformation.py): pick first contiguous block - handles system message error from google

Fixes https://github.com/BerriAI/litellm/issues/6738

* fix(vertex_ai/gemini/): track context caching tokens

* refactor(gemini/): place transformation.py inside `chat/` folder

make it easy for user to know we support the equivalent endpoint

* fix: fix import

* refactor(vertex_ai/): move vertex_ai cost calc inside vertex_ai/ folder

make it easier to see cost calculation logic

* fix: fix linting errors

* fix: fix circular import

* feat(gemini/cost_calculator.py): support gemini context caching cost calculation

generifies anthropic's cost calculation function and uses it across anthropic + gemini

* build(model_prices_and_context_window.json): add cost tracking for gemini-1.5-flash-002 w/ context caching

Closes https://github.com/BerriAI/litellm/issues/6891

* docs(gemini.md): add gemini context caching architecture diagram

make it easier for user to understand how context caching works

* docs(gemini.md): link to relevant gemini context caching code

* docs(gemini/context_caching): add readme in github, make it easy for dev to know context caching is supported + where to go for code

* fix(llm_cost_calc/utils.py): handle gemini 128k token diff cost calc scenario

* fix(deepseek/cost_calculator.py): support deepseek context caching cost calculation

* test: fix test
2024-12-23 22:02:52 -08:00

"""
Transformation logic for context caching.

Why separate file? Make it easy to see how transformation works
"""

from typing import List, Tuple

from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody
from litellm.utils import is_cached_message

from ..common_utils import get_supports_system_message
from ..gemini.transformation import (
    _gemini_convert_messages_with_history,
    _transform_system_message,
)


def get_first_continuous_block_idx(
    filtered_messages: List[Tuple[int, AllMessageValues]]  # (idx, message)
) -> int:
    """
    Find the array index that ends the first continuous sequence of message blocks.

    Args:
        filtered_messages: List of tuples containing (index, message) pairs

    Returns:
        int: The position in `filtered_messages` (not the original message
        index) where the first continuous sequence ends, or -1 if
        `filtered_messages` is empty
    """
    if not filtered_messages:
        return -1

    if len(filtered_messages) == 1:
        return 0

    current_value = filtered_messages[0][0]

    # Search forward through the array indices
    for i in range(1, len(filtered_messages)):
        if filtered_messages[i][0] != current_value + 1:
            return i - 1
        current_value = filtered_messages[i][0]

    # If we made it through the whole list, return the last index
    return len(filtered_messages) - 1
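A small worked example may help here. The function above returns a position inside `filtered_messages`, not an index into the original `messages` list. The sketch below re-implements the same forward scan over a plain list of message indices so it runs without litellm installed. The helper name `first_block_end` is hypothetical.

```python
from typing import List


def first_block_end(indices: List[int]) -> int:
    """Return the position in `indices` where the first run of
    consecutive integers ends, or -1 for an empty list.

    Hypothetical mirror of get_first_continuous_block_idx that takes
    bare message indices instead of (index, message) tuples.
    """
    if not indices:
        return -1
    for pos in range(1, len(indices)):
        # A gap means the first contiguous run ended at the previous position.
        if indices[pos] != indices[pos - 1] + 1:
            return pos - 1
    # No gap found: the whole list is one contiguous run.
    return len(indices) - 1


# Messages 1, 2 and 3 are cached, then a gap, then messages 6 and 7.
cached_indices = [1, 2, 3, 6, 7]
end_pos = first_block_end(cached_indices)
print(end_pos)  # prints "2"
print(cached_indices[0], cached_indices[end_pos])  # prints "1 3"
```

So the cached block spans original messages 1 through 3. Messages 6 and 7 fall outside the first contiguous run.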


def separate_cached_messages(
    messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
    """
    Returns separated cached and non-cached messages.

    Only the first contiguous block of cached messages is treated as cached;
    cached messages that appear after a gap are returned with the
    non-cached messages.

    Args:
        messages: List of messages to be separated.

    Returns:
        Tuple containing:
        - cached_messages: List of cached messages.
        - non_cached_messages: List of non-cached messages.
    """
    cached_messages: List[AllMessageValues] = []
    non_cached_messages: List[AllMessageValues] = []

    # Extract cached messages and their indices
    filtered_messages: List[Tuple[int, AllMessageValues]] = []
    for idx, message in enumerate(messages):
        if is_cached_message(message=message):
            filtered_messages.append((idx, message))

    # Keep only the first continuous block of cached messages
    last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
    # Separate messages based on the block of cached messages
    if filtered_messages and last_continuous_block_idx is not None:
        first_cached_idx = filtered_messages[0][0]
        last_cached_idx = filtered_messages[last_continuous_block_idx][0]

        cached_messages = messages[first_cached_idx : last_cached_idx + 1]
        non_cached_messages = (
            messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
        )
    else:
        non_cached_messages = messages

    return cached_messages, non_cached_messages
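The following self-contained sketch reproduces the same split logic end to end. It makes an illustrative assumption: a message counts as cached when one of its text content parts carries an Anthropic-style `"cache_control": {"type": "ephemeral"}` marker. The predicate `looks_cached` is a hypothetical stand-in for litellm's `is_cached_message`, and messages are plain dicts rather than `AllMessageValues`.

```python
from typing import Any, Dict, List, Tuple

Message = Dict[str, Any]


def looks_cached(message: Message) -> bool:
    # Assumed marker: a text part whose cache_control type is "ephemeral".
    content = message.get("content")
    if not isinstance(content, list):
        return False
    return any(
        part.get("type") == "text"
        and (part.get("cache_control") or {}).get("type") == "ephemeral"
        for part in content
    )


def split_cached(messages: List[Message]) -> Tuple[List[Message], List[Message]]:
    # Indices of every message carrying the (assumed) cache marker.
    idxs = [i for i, m in enumerate(messages) if looks_cached(m)]
    if not idxs:
        return [], messages
    # Extend from the first cached index while indices stay consecutive.
    start = end = idxs[0]
    for i in idxs[1:]:
        if i != end + 1:
            break
        end = i
    # Only the first contiguous block is cached; everything else,
    # including later marked messages, is returned as non-cached.
    return messages[start : end + 1], messages[:start] + messages[end + 1 :]


def text(role: str, body: str, cached: bool = False) -> Message:
    # Build an OpenAI-style message with a single text content part.
    part: Dict[str, Any] = {"type": "text", "text": body}
    if cached:
        part["cache_control"] = {"type": "ephemeral"}
    return {"role": role, "content": [part]}


msgs = [
    text("user", "long document, part 1", cached=True),
    text("user", "long document, part 2", cached=True),
    text("assistant", "ok"),
    text("user", "marked but not contiguous", cached=True),
    text("user", "question"),
]
cached, rest = split_cached(msgs)
print(len(cached), len(rest))  # prints "2 3"
```

The fourth message is marked but follows a gap, so it lands in the non-cached list. This matches how `separate_cached_messages` slices around the first block.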


def transform_openai_messages_to_gemini_context_caching(
    model: str, messages: List[AllMessageValues], cache_key: str
) -> CachedContentRequestBody:
    supports_system_message = get_supports_system_message(
        model=model, custom_llm_provider="gemini"
    )

    transformed_system_messages, new_messages = _transform_system_message(
        supports_system_message=supports_system_message, messages=messages
    )

    transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
    data = CachedContentRequestBody(
        contents=transformed_messages,
        model="models/{}".format(model),
        displayName=cache_key,
    )
    if transformed_system_messages is not None:
        data["system_instruction"] = transformed_system_messages

    return data
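The request body built above pairs the converted conversation with cache metadata. The sketch below shows roughly what that body looks like, using plain dicts. Several parts are assumptions. The role mapping (`assistant` to `model`) and the `parts`/`text` layout follow the public Gemini content format. `to_gemini_contents` and `split_system` are hypothetical, simplified stand-ins for `_gemini_convert_messages_with_history` and `_transform_system_message` that handle only plain string content. The sketch also assumes the model supports system instructions.

```python
from typing import Any, Dict, List, Optional, Tuple

Message = Dict[str, Any]


def to_gemini_contents(messages: List[Message]) -> List[Dict[str, Any]]:
    # Simplified conversion for string content only: OpenAI "assistant"
    # becomes Gemini "model"; every other role is sent as "user".
    return [
        {
            "role": "model" if m["role"] == "assistant" else "user",
            "parts": [{"text": m["content"]}],
        }
        for m in messages
    ]


def split_system(
    messages: List[Message],
) -> Tuple[Optional[Dict[str, Any]], List[Message]]:
    # Pull system messages out into a single system instruction.
    system_parts = [{"text": m["content"]} for m in messages if m["role"] == "system"]
    rest = [m for m in messages if m["role"] != "system"]
    return ({"parts": system_parts} if system_parts else None), rest


def build_cache_body(
    model: str, messages: List[Message], cache_key: str
) -> Dict[str, Any]:
    system_instruction, rest = split_system(messages)
    body: Dict[str, Any] = {
        "contents": to_gemini_contents(rest),
        "model": "models/{}".format(model),
        "displayName": cache_key,
    }
    # Same key the litellm code sets; omitted when there is no system text.
    if system_instruction is not None:
        body["system_instruction"] = system_instruction
    return body


body = build_cache_body(
    model="gemini-1.5-flash-002",
    messages=[
        {"role": "system", "content": "You are a contracts analyst."},
        {"role": "user", "content": "<long contract text>"},
    ],
    cache_key="contract-v1",
)
print(body["model"])  # prints "models/gemini-1.5-flash-002"
```

Using `cache_key` as `displayName` lets a later lookup identify the cache entry by name. The `models/` prefix matches the resource-name form the code above builds with `"models/{}".format(model)`.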