"""
|
|
Transformation logic for context caching.
|
|
|
|
Why separate file? Make it easy to see how transformation works
|
|
"""

from typing import List, Tuple

from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody
from litellm.utils import is_cached_message

from ..common_utils import VertexAIError, get_supports_system_message
from ..gemini.transformation import (
    _gemini_convert_messages_with_history,
    _transform_system_message,
)

def separate_cached_messages(
    messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
    """
    Returns separated cached and non-cached messages.

    Args:
        messages: List of messages to be separated.

    Returns:
        Tuple containing:
        - cached_messages: List of cached messages.
        - non_cached_messages: List of non-cached messages.
    """
    cached_messages: List[AllMessageValues] = []
    non_cached_messages: List[AllMessageValues] = []

    # Extract cached messages and their indices
    filtered_messages: List[Tuple[int, AllMessageValues]] = []
    for idx, message in enumerate(messages):
        if is_cached_message(message=message):
            filtered_messages.append((idx, message))

    # Validate only one block of continuous cached messages
    if len(filtered_messages) > 1:
        expected_idx = filtered_messages[0][0] + 1
        for idx, _ in filtered_messages[1:]:
            if idx != expected_idx:
                raise VertexAIError(
                    status_code=422,
                    message="Gemini Context Caching only supports 1 message/block of continuous messages. Your idx, messages were - {}".format(
                        filtered_messages
                    ),
                )
            expected_idx += 1

    # Separate messages based on the block of cached messages
    if filtered_messages:
        first_cached_idx = filtered_messages[0][0]
        last_cached_idx = filtered_messages[-1][0]

        cached_messages = messages[first_cached_idx : last_cached_idx + 1]
        non_cached_messages = (
            messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
        )
    else:
        non_cached_messages = messages

    return cached_messages, non_cached_messages
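
# Illustrative usage sketch (not part of the module itself): the message shape
# below assumes the Anthropic-style `cache_control` marker that
# `is_cached_message` checks for on text content blocks.
#
#   messages = [
#       {"role": "user", "content": "latest, uncached question"},
#       {
#           "role": "user",
#           "content": [
#               {
#                   "type": "text",
#                   "text": "<large document to cache>",
#                   "cache_control": {"type": "ephemeral"},
#               }
#           ],
#       },
#   ]
#   cached, non_cached = separate_cached_messages(messages=messages)
#   # cached     -> the single continuous block of cache_control-marked messages
#   # non_cached -> everything before and after that block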


def transform_openai_messages_to_gemini_context_caching(
    model: str, messages: List[AllMessageValues], cache_key: str
) -> CachedContentRequestBody:
    """
    Builds the CachedContentRequestBody for Gemini context caching from
    OpenAI-format messages.
    """
    supports_system_message = get_supports_system_message(
        model=model, custom_llm_provider="gemini"
    )

    transformed_system_messages, new_messages = _transform_system_message(
        supports_system_message=supports_system_message, messages=messages
    )

    transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
    data = CachedContentRequestBody(
        contents=transformed_messages,
        model="models/{}".format(model),
        displayName=cache_key,
    )
    if transformed_system_messages is not None:
        data["system_instruction"] = transformed_system_messages

    return data
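
# Illustrative usage sketch: the cache_key shown here is a made-up placeholder;
# in practice it is derived from the cached messages by the caller. The
# resulting request body is what gets sent to the Gemini `cachedContents`
# endpoint to create the cached content.
#
#   cached, _ = separate_cached_messages(messages=messages)
#   request_body = transform_openai_messages_to_gemini_context_caching(
#       model="gemini-1.5-pro-001",
#       messages=cached,
#       cache_key="example-cache-key",
#   )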
|