feat(vertex_ai_context_caching.py): support making context caching calls to vertex ai in a normal chat completion call (anthropic caching format)

Closes https://github.com/BerriAI/litellm/issues/5213
Krrish Dholakia 2024-08-26 18:47:45 -07:00
parent 6ff17f1acd
commit 074e30fa10
16 changed files with 594 additions and 90 deletions
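For orientation, a minimal usage sketch of what this commit enables, modeled on the test added below; the model name and message text are illustrative and a GEMINI_API_KEY is assumed to be set. Marking a content block with the Anthropic-style cache_control field makes litellm create a Gemini cachedContents entry before the normal chat completion call.

import litellm

# Illustrative only - mirrors test_gemini_context_caching_anthropic_format below.
response = litellm.completion(
    model="gemini/gemini-1.5-flash-001",
    messages=[
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "Here is the full text of a complex legal agreement" * 400,
                    # Anthropic-style marker: this block is written to Gemini's
                    # cachedContents endpoint and dropped from the chat request.
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        {
            "role": "user",
            "content": "What are the key terms and conditions in this agreement?",
        },
    ],
    max_tokens=10,
)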


@@ -848,15 +848,21 @@ from .llms.gemini import GeminiConfig
 from .llms.nlp_cloud import NLPCloudConfig
 from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
-from .llms.vertex_httpx import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
     GoogleAIStudioGeminiConfig,
     VertexAIConfig,
 )
-from .llms.vertex_ai import VertexAITextEmbeddingConfig
-from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
-from .llms.vertex_ai_partner import VertexAILlama3Config
-from .llms.sagemaker.sagemaker import SagemakerConfig
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+    VertexAITextEmbeddingConfig,
+)
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
+    VertexAIAnthropicConfig,
+)
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+    VertexAILlama3Config,
+)
+from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig


@@ -8,7 +8,9 @@ from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparamet
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.llms.vertex_httpx import VertexLLM
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    VertexLLM,
+)
 from litellm.types.llms.openai import FineTuningJobCreate
 from litellm.types.llms.vertex_ai import (
     FineTuneJobCreate,


@@ -13,7 +13,9 @@ from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
 )
 from litellm.llms.openai import HttpxBinaryResponseContent
-from litellm.llms.vertex_httpx import VertexLLM
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    VertexLLM,
+)


 class VertexInput(TypedDict, total=False):


@@ -0,0 +1,39 @@
from typing import Literal
import httpx
from litellm import supports_system_messages, verbose_logger


class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
        )  # Call the base class constructor with the parameters it needs


def get_supports_system_message(
model: str, custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"]
) -> bool:
try:
_custom_llm_provider = custom_llm_provider
if custom_llm_provider == "vertex_ai_beta":
_custom_llm_provider = "vertex_ai"
supports_system_message = supports_system_messages(
model=model, custom_llm_provider=_custom_llm_provider
)
except Exception as e:
verbose_logger.warning(
"Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
str(e)
)
)
supports_system_message = False
return supports_system_message
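A brief usage sketch of the helper above (model name illustrative): it consults litellm's model map and degrades to False on any lookup failure, which callers use to decide whether a separate system instruction can be sent.

# Sketch - assumes get_supports_system_message is imported from this new common_utils module.
if get_supports_system_message(
    model="gemini-1.5-flash-001", custom_llm_provider="gemini"
):
    # safe to pull system turns out into a Gemini system_instruction
    ...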


@@ -0,0 +1,88 @@
"""
Transformation logic for context caching.
Why separate file? Make it easy to see how transformation works
"""
from typing import List, Tuple
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody, SystemInstructions
from litellm.utils import is_cached_message
from ..common_utils import VertexAIError, get_supports_system_message
from ..gemini_transformation import transform_system_message
from ..vertex_and_google_ai_studio_gemini import _gemini_convert_messages_with_history


def separate_cached_messages(
messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
"""
Returns separated cached and non-cached messages.
Args:
messages: List of messages to be separated.
Returns:
Tuple containing:
- cached_messages: List of cached messages.
- non_cached_messages: List of non-cached messages.
"""
cached_messages: List[AllMessageValues] = []
non_cached_messages: List[AllMessageValues] = []
# Extract cached messages and their indices
filtered_messages: List[Tuple[int, AllMessageValues]] = []
for idx, message in enumerate(messages):
if is_cached_message(message=message):
filtered_messages.append((idx, message))
# Validate only one block of continuous cached messages
if len(filtered_messages) > 1:
expected_idx = filtered_messages[0][0] + 1
for idx, _ in filtered_messages[1:]:
if idx != expected_idx:
raise VertexAIError(
status_code=422,
message="Gemini Context Caching only supports 1 message/block of continuous messages. Your idx, messages were - {}".format(
filtered_messages
),
)
expected_idx += 1
# Separate messages based on the block of cached messages
if filtered_messages:
first_cached_idx = filtered_messages[0][0]
last_cached_idx = filtered_messages[-1][0]
cached_messages = messages[first_cached_idx : last_cached_idx + 1]
non_cached_messages = (
messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
)
else:
non_cached_messages = messages
    return cached_messages, non_cached_messages


def transform_openai_messages_to_gemini_context_caching(
model: str,
messages: List[AllMessageValues],
) -> CachedContentRequestBody:
supports_system_message = get_supports_system_message(
model=model, custom_llm_provider="gemini"
)
transformed_system_messages, new_messages = transform_system_message(
supports_system_message=supports_system_message, messages=messages
)
transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
data = CachedContentRequestBody(
contents=transformed_messages, model="models/{}".format(model)
)
if transformed_system_messages is not None:
data["system_instruction"] = transformed_system_messages
return data
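To make the flow above concrete, a hedged sketch of how these two helpers compose; the message text is made up and mirrors the Anthropic-format test added later in this commit.

msgs = [
    {  # marked for caching
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Here is a long shared legal agreement ...",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {  # also marked - must be contiguous with the block above
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "What are the key terms?",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {"role": "user", "content": "And the termination clause?"},
]

cached, non_cached = separate_cached_messages(messages=msgs)
# cached      -> the first two (contiguous cache_control-marked) messages
# non_cached  -> the trailing, unmarked user turn

request_body = transform_openai_messages_to_gemini_context_caching(
    model="gemini-1.5-flash-001", messages=cached
)
# request_body["model"] == "models/gemini-1.5-flash-001"; if the model map marks
# system-message support, the system text lands in request_body["system_instruction"]
# and the cached user turn in request_body["contents"].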


@@ -0,0 +1,170 @@
import types
from typing import Callable, List, Literal, Optional, Tuple, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import (
RequestBody,
VertexAICachedContentResponseObject,
)
from litellm.utils import ModelResponse
from ..common_utils import VertexAIError
from .transformation import (
separate_cached_messages,
transform_openai_messages_to_gemini_context_caching,
)


class ContextCachingEndpoints:
"""
Covers context caching endpoints for Vertex AI + Google AI Studio
v0: covers Google AI Studio
"""
def __init__(self) -> None:
pass
def _get_token_and_url(
self,
model: str,
gemini_api_key: Optional[str],
custom_llm_provider: Literal["gemini"],
api_base: Optional[str],
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
Handles logic if it's google ai studio vs. vertex ai.
Returns
token, url
"""
if custom_llm_provider == "gemini":
_gemini_model_name = "models/{}".format(model)
auth_header = None
endpoint = "cachedContents"
url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
endpoint, gemini_api_key
)
else:
raise NotImplementedError
if (
api_base is not None
): # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
if custom_llm_provider == "gemini":
url = "{}/{}".format(api_base, endpoint)
auth_header = (
gemini_api_key # cloudflare expects api key as bearer token
)
else:
url = "{}:{}".format(api_base, endpoint)
        return auth_header, url

    def create_cache(
self,
messages: List[AllMessageValues], # receives openai format messages
api_key: str,
api_base: Optional[str],
model: str,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
timeout: Optional[Union[float, httpx.Timeout]],
logging_obj: Logging,
extra_headers: Optional[dict] = None,
cached_content: Optional[str] = None,
) -> Tuple[List[AllMessageValues], Optional[str]]:
"""
Receives
- messages: List of dict - messages in the openai format
Returns
- messages - List[dict] - filtered list of messages in the openai format.
- cached_content - str - the cache content id, to be passed in the gemini request body
Follows - https://ai.google.dev/api/caching#request-body
"""
if cached_content is not None:
return messages, cached_content
## AUTHORIZATION ##
token, url = self._get_token_and_url(
model=model,
gemini_api_key=api_key,
custom_llm_provider="gemini",
api_base=api_base,
)
headers = {
"Content-Type": "application/json",
}
if token is not None:
headers["Authorization"] = f"Bearer {token}"
if extra_headers is not None:
headers.update(extra_headers)
if client is None or not isinstance(client, HTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = HTTPHandler(**_params) # type: ignore
else:
client = client
cached_messages, non_cached_messages = separate_cached_messages(
messages=messages
)
if len(cached_messages) == 0:
return messages, None
cached_content_request_body = (
transform_openai_messages_to_gemini_context_caching(
model=model, messages=cached_messages
)
)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": cached_content_request_body,
"api_base": url,
"headers": headers,
},
)
try:
response = client.post(
url=url, headers=headers, json=cached_content_request_body # type: ignore
)
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise VertexAIError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
raw_response_cached = response.json()
cached_content_response_obj = VertexAICachedContentResponseObject(
name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
)
        return (non_cached_messages, cached_content_response_obj["name"])

    def async_create_cache(self):
        pass

    def get_cache(self):
        pass

    async def async_get_cache(self):
        pass
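For orientation, the Google AI Studio call this class issues looks roughly as follows; the key, cache id, and field values are placeholders taken from the request/response types and the mocked test later in this commit, and only the Google AI Studio ("gemini") path is implemented so far.

# POST https://generativelanguage.googleapis.com/v1beta/cachedContents?key=<GEMINI_API_KEY>
# body (CachedContentRequestBody):
#   {
#     "model": "models/gemini-1.5-flash-001",
#     "contents": [...],                       # only the cache_control-marked turns
#     "system_instruction": {"parts": [...]},  # when system messages are supported
#   }
#
# 200 response (see the mocked test below):
#   {"name": "cachedContents/4d2kd477o3pg", "model": "models/gemini-1.5-flash-001", ...}
#
# create_cache() then returns (non_cached_messages, "cachedContents/4d2kd477o3pg").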


@@ -0,0 +1,47 @@
"""
Transformation logic from OpenAI format to Gemini format.
Why separate file? Make it easy to see how transformation works
"""
from typing import List, Optional, Tuple
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import PartType, SystemInstructions


def transform_system_message(
supports_system_message: bool, messages: List[AllMessageValues]
) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]:
"""
Extracts the system message from the openai message list.
Converts the system message to Gemini format
Returns
- system_content_blocks: Optional[SystemInstructions] - the system message list in Gemini format.
- messages: List[AllMessageValues] - filtered list of messages in OpenAI format (transformed separately)
"""
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[PartType] = []
if supports_system_message is True:
for idx, message in enumerate(messages):
if message["role"] == "system":
if isinstance(message["content"], str):
_system_content_block = PartType(text=message["content"])
elif isinstance(message["content"], list):
system_text = ""
for content in message["content"]:
system_text += content.get("text") or ""
_system_content_block = PartType(text=system_text)
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_content_blocks) > 0:
return SystemInstructions(parts=system_content_blocks), messages
return None, messages
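A small illustrative example of the helper above (the strings are made up):

system_instruction, remaining = transform_system_message(
    supports_system_message=True,
    messages=[
        {"role": "system", "content": "You are a careful legal assistant."},
        {"role": "user", "content": "Summarize the agreement."},
    ],
)
# system_instruction == {"parts": [{"text": "You are a careful legal assistant."}]}
# remaining          == [{"role": "user", "content": "Summarize the agreement."}]
# With supports_system_message=False nothing is popped: the system turn stays in
# `remaining` and system_instruction is None.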


@@ -26,7 +26,7 @@ from litellm.types.llms.openai import (
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
-from .prompt_templates.factory import (
+from ..prompt_templates.factory import (
     construct_tool_use_system_prompt,
     contains_tag,
     custom_prompt,


@@ -1,41 +1,14 @@
 # What is this?
 ## Handler for calling llama 3.1 API on Vertex AI
-import copy
-import json
-import os
-import time
 import types
-import uuid
-from enum import Enum
-from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+from typing import Callable, Literal, Optional, Union

 import httpx  # type: ignore
-import requests  # type: ignore

 import litellm
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import (
-    AnthropicMessagesTool,
-    AnthropicMessagesToolChoice,
-)
-from litellm.types.llms.openai import (
-    ChatCompletionToolParam,
-    ChatCompletionToolParamFunctionChunk,
-)
-from litellm.types.utils import ResponseFormatChunk
-from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+from litellm.utils import ModelResponse

-from .base import BaseLLM
-from .prompt_templates.factory import (
-    construct_tool_use_system_prompt,
-    contains_tag,
-    custom_prompt,
-    extract_between_tags,
-    parse_xml_params,
-    prompt_factory,
-    response_schema_prompt,
-)
+from ..base import BaseLLM


 class VertexAIError(Exception):


@@ -9,7 +9,7 @@ import types
 import uuid
 from enum import Enum
 from functools import partial
-from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -25,7 +25,9 @@ from litellm.llms.prompt_templates.factory import (
     convert_url_to_base64,
     response_schema_prompt,
 )
-from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+    _gemini_convert_messages_with_history,
+)
 from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
@@ -52,7 +54,12 @@ from litellm.types.llms.vertex_ai import (
 from litellm.types.utils import GenericStreamingChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
-from .base import BaseLLM
+from ..base import BaseLLM
+from .common_utils import VertexAIError, get_supports_system_message
+from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+from .gemini_transformation import transform_system_message
+
+context_caching_endpoints = ContextCachingEndpoints()


 class VertexAIConfig:
@@ -789,19 +796,6 @@ def make_sync_call(
     return completion_stream


-class VertexAIError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST", url=" https://cloud.google.com/vertex-ai/"
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
 class VertexLLM(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -1366,33 +1360,27 @@ class VertexLLM(BaseLLM):
         )

         ## TRANSFORMATION ##
-        try:
-            _custom_llm_provider = custom_llm_provider
-            if custom_llm_provider == "vertex_ai_beta":
-                _custom_llm_provider = "vertex_ai"
-            supports_system_message = litellm.supports_system_messages(
-                model=model, custom_llm_provider=_custom_llm_provider
+        ### CHECK CONTEXT CACHING ###
+        if gemini_api_key is not None:
+            messages, cached_content = context_caching_endpoints.create_cache(
+                messages=messages,
+                api_key=gemini_api_key,
+                api_base=api_base,
+                model=model,
+                client=client,
+                timeout=timeout,
+                extra_headers=extra_headers,
+                cached_content=optional_params.pop("cached_content", None),
+                logging_obj=logging_obj,
             )
-        except Exception as e:
-            verbose_logger.warning(
-                "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
-                    str(e)
-                )
-            )
-            supports_system_message = False
-        # Separate system prompt from rest of message
-        system_prompt_indices = []
-        system_content_blocks: List[PartType] = []
-        if supports_system_message is True:
-            for idx, message in enumerate(messages):
-                if message["role"] == "system":
-                    _system_content_block = PartType(text=message["content"])
-                    system_content_blocks.append(_system_content_block)
-                    system_prompt_indices.append(idx)
-            if len(system_prompt_indices) > 0:
-                for idx in reversed(system_prompt_indices):
-                    messages.pop(idx)
+
+        # Separate system prompt from rest of message
+        supports_system_message = get_supports_system_message(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        system_instructions, messages = transform_system_message(
+            supports_system_message=supports_system_message, messages=messages
+        )

         # Checks for 'response_schema' support - if passed in
         if "response_schema" in optional_params:
             supports_response_schema = litellm.supports_response_schema(
@@ -1426,13 +1414,11 @@ class VertexLLM(BaseLLM):
         safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
             "safety_settings", None
         )  # type: ignore
-        cached_content: Optional[str] = optional_params.pop("cached_content", None)
         generation_config: Optional[GenerationConfig] = GenerationConfig(
             **optional_params
         )
         data = RequestBody(contents=content)
-        if len(system_content_blocks) > 0:
-            system_instructions = SystemInstructions(parts=system_content_blocks)
+        if system_instructions is not None:
             data["system_instruction"] = system_instructions
         if tools is not None:
             data["tools"] = tools


@@ -95,8 +95,6 @@ from .llms import (
     replicate,
     together_ai,
     triton,
-    vertex_ai,
-    vertex_ai_anthropic,
     vllm,
     watsonx,
 )
@@ -124,8 +122,16 @@ from .llms.sagemaker.sagemaker import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.text_to_speech.vertex_ai import VertexTextToSpeechAPI
 from .llms.triton import TritonChatCompletion
-from .llms.vertex_ai_partner import VertexAIPartnerModels
-from .llms.vertex_httpx import VertexLLM
+from .llms.vertex_ai_and_google_ai_studio import (
+    vertex_ai_anthropic,
+    vertex_ai_non_gemini,
+)
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+    VertexAIPartnerModels,
+)
+from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    VertexLLM,
+)
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
 from .types.utils import (
@@ -2112,7 +2118,7 @@ def completion(
                 extra_headers=extra_headers,
             )
         else:
-            model_response = vertex_ai.completion(
+            model_response = vertex_ai_non_gemini.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
@@ -3558,7 +3564,7 @@ def embedding(
                 print_verbose=print_verbose,
             )
         else:
-            response = vertex_ai.embedding(
+            response = vertex_ai_non_gemini.embedding(
                 model=model,
                 input=input,
                 encoding=encoding,


@@ -28,7 +28,9 @@ from litellm import (
     completion_cost,
     embedding,
 )
-from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    _gemini_convert_messages_with_history,
+)
 from litellm.tests.test_streaming import streaming_format_tests

 litellm.num_retries = 3
@@ -2199,3 +2201,137 @@ async def test_completion_fine_tuned_model():
    # Optional: Print for debugging
    print("Arguments passed to Vertex AI:", args_to_vertexai)
    print("Response:", response)

def mock_gemini_request(*args, **kwargs):
print(f"kwargs: {kwargs}")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
if "cachedContents" in kwargs["url"]:
mock_response.json.return_value = {
"name": "cachedContents/4d2kd477o3pg",
"model": "models/gemini-1.5-flash-001",
"createTime": "2024-08-26T22:31:16.147190Z",
"updateTime": "2024-08-26T22:31:16.147190Z",
"expireTime": "2024-08-26T22:36:15.548934784Z",
"displayName": "",
"usageMetadata": {"totalTokenCount": 323383},
}
else:
mock_response.json.return_value = {
"candidates": [
{
"content": {
"parts": [
{
"text": "Please provide me with the text of the legal agreement"
}
],
"role": "model",
},
"finishReason": "MAX_TOKENS",
"index": 0,
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
},
],
}
],
"usageMetadata": {
"promptTokenCount": 40049,
"candidatesTokenCount": 10,
"totalTokenCount": 40059,
"cachedContentTokenCount": 40012,
},
}
    return mock_response


@pytest.mark.asyncio
async def test_gemini_context_caching_anthropic_format():
from litellm.llms.custom_httpx.http_handler import HTTPHandler
litellm.set_verbose = True
client = HTTPHandler(concurrent_limit=1)
with patch.object(client, "post", side_effect=mock_gemini_request) as mock_client:
try:
response = litellm.completion(
model="gemini/gemini-1.5-flash-001",
messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement"
* 4000,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
}
],
},
],
temperature=0.2,
max_tokens=10,
client=client,
)
except Exception as e:
print(e)
assert mock_client.call_count == 2
first_call_args = mock_client.call_args_list[0].kwargs
print(f"first_call_args: {first_call_args}")
assert "cachedContents" in first_call_args["url"]
# assert "cache_read_input_tokens" in response.usage
# assert "cache_creation_input_tokens" in response.usage
# # Assert either a cache entry was created or cache was read - changes depending on the anthropic api ttl
# assert (response.usage.cache_read_input_tokens > 0) or (
# response.usage.cache_creation_input_tokens > 0
# )


@@ -325,11 +325,21 @@ class ChatCompletionDeltaToolCallChunk(TypedDict, total=False):
     index: int


-class ChatCompletionTextObject(TypedDict):
+class ChatCompletionCachedContent(TypedDict):
+    type: Literal["ephemeral"]
+
+
+class OpenAIChatCompletionTextObject(TypedDict):
     type: Literal["text"]
     text: str

+
+class ChatCompletionTextObject(
+    OpenAIChatCompletionTextObject, total=False
+):  # litellm wrapper on top of openai object for handling cached content
+    cache_control: ChatCompletionCachedContent
+

 class ChatCompletionImageUrlObject(TypedDict, total=False):
     url: Required[str]
     detail: str
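For illustration, a content block carrying the new, litellm-only cache marker (the text is a placeholder):

block: ChatCompletionTextObject = {
    "type": "text",
    "text": "Here is the full text of a complex legal agreement ...",
    "cache_control": {"type": "ephemeral"},  # optional - the litellm extension defined above
}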


@@ -186,6 +186,17 @@ class RequestBody(TypedDict, total=False):
     cachedContent: str


+class CachedContentRequestBody(TypedDict, total=False):
+    contents: Required[List[ContentType]]
+    system_instruction: SystemInstructions
+    tools: Tools
+    toolConfig: ToolConfig
+    model: Required[str]  # Format: models/{model}
+    ttl: str  # ending in 's' - Example: "3.5s".
+    name: str  # Format: cachedContents/{id}
+    displayName: str
+
+
 class SafetyRatings(TypedDict):
     category: HarmCategory
     probability: HarmProbability
@@ -320,3 +331,8 @@ class Instance(TypedDict, total=False):

 class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
     instances: List[Instance]
+
+
+class VertexAICachedContentResponseObject(TypedDict):
+    name: str
+    model: str
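An illustrative cache-creation body built from these types; the values are placeholders, and the transformation added in this commit only populates contents, model, and optionally system_instruction.

from litellm.types.llms.vertex_ai import CachedContentRequestBody

body: CachedContentRequestBody = {
    "model": "models/gemini-1.5-flash-001",  # Format: models/{model}
    "contents": [
        {"role": "user", "parts": [{"text": "Here is a long shared document ..."}]}
    ],
    "system_instruction": {"parts": [{"text": "You are a careful legal assistant."}]},
}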


@@ -69,6 +69,7 @@ from litellm.litellm_core_utils.redact_messages import (
 from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.llms.openai import (
+    AllMessageValues,
     ChatCompletionNamedToolChoiceParam,
     ChatCompletionToolParam,
 )
@@ -11549,3 +11550,25 @@ class ModelResponseListIterator:
 class CustomModelResponseIterator(Iterable):
     def __init__(self) -> None:
         super().__init__()
+
+
+def is_cached_message(message: AllMessageValues) -> bool:
+    """
+    Returns true, if message is marked as needing to be cached.
+    Used for anthropic/gemini context caching.
+    Follows the anthropic format {"cache_control": {"type": "ephemeral"}}
+    """
+    if message["content"] is None or isinstance(message["content"], str):
+        return False
+
+    for content in message["content"]:
+        if (
+            content["type"] == "text"
+            and content.get("cache_control") is not None
+            and content["cache_control"]["type"] == "ephemeral"  # type: ignore
+        ):
+            return True
+
+    return False
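A quick illustration of the helper (contents are placeholders):

assert is_cached_message(
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "long shared context ...",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    }
)
assert not is_cached_message({"role": "user", "content": "plain string content"})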