diff --git a/litellm/__init__.py b/litellm/__init__.py index eb8fd56d0a..03ca0482af 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -848,15 +848,21 @@ from .llms.gemini import GeminiConfig from .llms.nlp_cloud import NLPCloudConfig from .llms.aleph_alpha import AlephAlphaConfig from .llms.petals import PetalsConfig -from .llms.vertex_httpx import ( +from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( VertexGeminiConfig, GoogleAIStudioGeminiConfig, VertexAIConfig, ) -from .llms.vertex_ai import VertexAITextEmbeddingConfig -from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig -from .llms.vertex_ai_partner import VertexAILlama3Config -from .llms.sagemaker.sagemaker import SagemakerConfig +from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import ( + VertexAITextEmbeddingConfig, +) +from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import ( + VertexAIAnthropicConfig, +) +from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import ( + VertexAILlama3Config, +) +from .llms.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig from .llms.maritalk import MaritTalkConfig diff --git a/litellm/llms/fine_tuning_apis/vertex_ai.py b/litellm/llms/fine_tuning_apis/vertex_ai.py index ffdb82c5b4..e87a9bf3c4 100644 --- a/litellm/llms/fine_tuning_apis/vertex_ai.py +++ b/litellm/llms/fine_tuning_apis/vertex_ai.py @@ -8,7 +8,9 @@ from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparamet from litellm._logging import verbose_logger from litellm.llms.base import BaseLLM from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.llms.vertex_httpx import VertexLLM +from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( + VertexLLM, +) from litellm.types.llms.openai import FineTuningJobCreate from litellm.types.llms.vertex_ai import ( FineTuneJobCreate, diff --git a/litellm/llms/text_to_speech/vertex_ai.py b/litellm/llms/text_to_speech/vertex_ai.py index 168a9eeb09..b9fca53250 100644 --- a/litellm/llms/text_to_speech/vertex_ai.py +++ b/litellm/llms/text_to_speech/vertex_ai.py @@ -13,7 +13,9 @@ from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, ) from litellm.llms.openai import HttpxBinaryResponseContent -from litellm.llms.vertex_httpx import VertexLLM +from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( + VertexLLM, +) class VertexInput(TypedDict, total=False): diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py new file mode 100644 index 0000000000..8faf7a3afa --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py @@ -0,0 +1,39 @@ +from typing import Literal + +import httpx + +from litellm import supports_system_messages, verbose_logger + + +class VertexAIError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request( + method="POST", url=" https://cloud.google.com/vertex-ai/" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +def get_supports_system_message( + model: str, custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"] +) -> bool: + try: + 
_custom_llm_provider = custom_llm_provider + if custom_llm_provider == "vertex_ai_beta": + _custom_llm_provider = "vertex_ai" + supports_system_message = supports_system_messages( + model=model, custom_llm_provider=_custom_llm_provider + ) + except Exception as e: + verbose_logger.warning( + "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format( + str(e) + ) + ) + supports_system_message = False + + return supports_system_message diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py new file mode 100644 index 0000000000..37c89838c5 --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py @@ -0,0 +1,88 @@ +""" +Transformation logic for context caching. + +Why separate file? Make it easy to see how transformation works +""" + +from typing import List, Tuple + +from litellm.types.llms.openai import AllMessageValues +from litellm.types.llms.vertex_ai import CachedContentRequestBody, SystemInstructions +from litellm.utils import is_cached_message + +from ..common_utils import VertexAIError, get_supports_system_message +from ..gemini_transformation import transform_system_message +from ..vertex_and_google_ai_studio_gemini import _gemini_convert_messages_with_history + + +def separate_cached_messages( + messages: List[AllMessageValues], +) -> Tuple[List[AllMessageValues], List[AllMessageValues]]: + """ + Returns separated cached and non-cached messages. + + Args: + messages: List of messages to be separated. + + Returns: + Tuple containing: + - cached_messages: List of cached messages. + - non_cached_messages: List of non-cached messages. + """ + cached_messages: List[AllMessageValues] = [] + non_cached_messages: List[AllMessageValues] = [] + + # Extract cached messages and their indices + filtered_messages: List[Tuple[int, AllMessageValues]] = [] + for idx, message in enumerate(messages): + if is_cached_message(message=message): + filtered_messages.append((idx, message)) + + # Validate only one block of continuous cached messages + if len(filtered_messages) > 1: + expected_idx = filtered_messages[0][0] + 1 + for idx, _ in filtered_messages[1:]: + if idx != expected_idx: + raise VertexAIError( + status_code=422, + message="Gemini Context Caching only supports 1 message/block of continuous messages. 
Your idx, messages were - {}".format( + filtered_messages + ), + ) + expected_idx += 1 + + # Separate messages based on the block of cached messages + if filtered_messages: + first_cached_idx = filtered_messages[0][0] + last_cached_idx = filtered_messages[-1][0] + + cached_messages = messages[first_cached_idx : last_cached_idx + 1] + non_cached_messages = ( + messages[:first_cached_idx] + messages[last_cached_idx + 1 :] + ) + else: + non_cached_messages = messages + + return cached_messages, non_cached_messages + + +def transform_openai_messages_to_gemini_context_caching( + model: str, + messages: List[AllMessageValues], +) -> CachedContentRequestBody: + supports_system_message = get_supports_system_message( + model=model, custom_llm_provider="gemini" + ) + + transformed_system_messages, new_messages = transform_system_message( + supports_system_message=supports_system_message, messages=messages + ) + + transformed_messages = _gemini_convert_messages_with_history(messages=new_messages) + data = CachedContentRequestBody( + contents=transformed_messages, model="models/{}".format(model) + ) + if transformed_system_messages is not None: + data["system_instruction"] = transformed_system_messages + + return data diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py new file mode 100644 index 0000000000..6461fe04b8 --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py @@ -0,0 +1,170 @@ +import types +from typing import Callable, List, Literal, Optional, Tuple, Union + +import httpx + +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.llms.openai import AllMessageValues +from litellm.types.llms.vertex_ai import ( + RequestBody, + VertexAICachedContentResponseObject, +) +from litellm.utils import ModelResponse + +from ..common_utils import VertexAIError +from .transformation import ( + separate_cached_messages, + transform_openai_messages_to_gemini_context_caching, +) + + +class ContextCachingEndpoints: + """ + Covers context caching endpoints for Vertex AI + Google AI Studio + + v0: covers Google AI Studio + """ + + def __init__(self) -> None: + pass + + def _get_token_and_url( + self, + model: str, + gemini_api_key: Optional[str], + custom_llm_provider: Literal["gemini"], + api_base: Optional[str], + ) -> Tuple[Optional[str], str]: + """ + Internal function. Returns the token and url for the call. + + Handles logic if it's google ai studio vs. vertex ai. 
+ + Returns + token, url + """ + if custom_llm_provider == "gemini": + _gemini_model_name = "models/{}".format(model) + auth_header = None + endpoint = "cachedContents" + url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format( + endpoint, gemini_api_key + ) + + else: + raise NotImplementedError + if ( + api_base is not None + ): # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317 + if custom_llm_provider == "gemini": + url = "{}/{}".format(api_base, endpoint) + auth_header = ( + gemini_api_key # cloudflare expects api key as bearer token + ) + else: + url = "{}:{}".format(api_base, endpoint) + + return auth_header, url + + def create_cache( + self, + messages: List[AllMessageValues], # receives openai format messages + api_key: str, + api_base: Optional[str], + model: str, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]], + timeout: Optional[Union[float, httpx.Timeout]], + logging_obj: Logging, + extra_headers: Optional[dict] = None, + cached_content: Optional[str] = None, + ) -> Tuple[List[AllMessageValues], Optional[str]]: + """ + Receives + - messages: List of dict - messages in the openai format + + Returns + - messages - List[dict] - filtered list of messages in the openai format. + - cached_content - str - the cache content id, to be passed in the gemini request body + + Follows - https://ai.google.dev/api/caching#request-body + """ + if cached_content is not None: + return messages, cached_content + + ## AUTHORIZATION ## + token, url = self._get_token_and_url( + model=model, + gemini_api_key=api_key, + custom_llm_provider="gemini", + api_base=api_base, + ) + + headers = { + "Content-Type": "application/json", + } + if token is not None: + headers["Authorization"] = f"Bearer {token}" + if extra_headers is not None: + headers.update(extra_headers) + + if client is None or not isinstance(client, HTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = HTTPHandler(**_params) # type: ignore + else: + client = client + + cached_messages, non_cached_messages = separate_cached_messages( + messages=messages + ) + + if len(cached_messages) == 0: + return messages, None + + cached_content_request_body = ( + transform_openai_messages_to_gemini_context_caching( + model=model, messages=cached_messages + ) + ) + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": cached_content_request_body, + "api_base": url, + "headers": headers, + }, + ) + + try: + response = client.post( + url=url, headers=headers, json=cached_content_request_body # type: ignore + ) + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise VertexAIError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException: + raise VertexAIError(status_code=408, message="Timeout error occurred.") + + raw_response_cached = response.json() + cached_content_response_obj = VertexAICachedContentResponseObject( + name=raw_response_cached.get("name"), model=raw_response_cached.get("model") + ) + return (non_cached_messages, cached_content_response_obj["name"]) + + def async_create_cache(self): + pass + + def get_cache(self): + pass + + async def async_get_cache(self): + pass diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini_transformation.py 
b/litellm/llms/vertex_ai_and_google_ai_studio/gemini_transformation.py new file mode 100644 index 0000000000..a25cd2db7a --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini_transformation.py @@ -0,0 +1,47 @@ +""" +Transformation logic from OpenAI format to Gemini format. + +Why separate file? Make it easy to see how transformation works +""" + +from typing import List, Optional, Tuple + +from litellm.types.llms.openai import AllMessageValues +from litellm.types.llms.vertex_ai import PartType, SystemInstructions + + +def transform_system_message( + supports_system_message: bool, messages: List[AllMessageValues] +) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]: + """ + Extracts the system message from the openai message list. + + Converts the system message to Gemini format + + Returns + - system_content_blocks: Optional[SystemInstructions] - the system message list in Gemini format. + - messages: List[AllMessageValues] - filtered list of messages in OpenAI format (transformed separately) + """ + # Separate system prompt from rest of message + system_prompt_indices = [] + system_content_blocks: List[PartType] = [] + if supports_system_message is True: + for idx, message in enumerate(messages): + if message["role"] == "system": + if isinstance(message["content"], str): + _system_content_block = PartType(text=message["content"]) + elif isinstance(message["content"], list): + system_text = "" + for content in message["content"]: + system_text += content.get("text") or "" + _system_content_block = PartType(text=system_text) + system_content_blocks.append(_system_content_block) + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) + + if len(system_content_blocks) > 0: + return SystemInstructions(parts=system_content_blocks), messages + + return None, messages diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py similarity index 99% rename from litellm/llms/vertex_ai_anthropic.py rename to litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py index 5887458527..bade93c2f1 100644 --- a/litellm/llms/vertex_ai_anthropic.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py @@ -26,7 +26,7 @@ from litellm.types.llms.openai import ( from litellm.types.utils import ResponseFormatChunk from litellm.utils import CustomStreamWrapper, ModelResponse, Usage -from .prompt_templates.factory import ( +from ..prompt_templates.factory import ( construct_tool_use_system_prompt, contains_tag, custom_prompt, diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py similarity index 100% rename from litellm/llms/vertex_ai.py rename to litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py diff --git a/litellm/llms/vertex_ai_partner.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models.py similarity index 89% rename from litellm/llms/vertex_ai_partner.py rename to litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models.py index 24586a3fe4..3521ba06d4 100644 --- a/litellm/llms/vertex_ai_partner.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models.py @@ -1,41 +1,14 @@ # What is this? 
## Handler for calling llama 3.1 API on Vertex AI -import copy -import json -import os -import time import types -import uuid -from enum import Enum -from typing import Any, Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, Literal, Optional, Union import httpx # type: ignore -import requests # type: ignore import litellm -from litellm.litellm_core_utils.core_helpers import map_finish_reason -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.types.llms.anthropic import ( - AnthropicMessagesTool, - AnthropicMessagesToolChoice, -) -from litellm.types.llms.openai import ( - ChatCompletionToolParam, - ChatCompletionToolParamFunctionChunk, -) -from litellm.types.utils import ResponseFormatChunk -from litellm.utils import CustomStreamWrapper, ModelResponse, Usage +from litellm.utils import ModelResponse -from .base import BaseLLM -from .prompt_templates.factory import ( - construct_tool_use_system_prompt, - contains_tag, - custom_prompt, - extract_between_tags, - parse_xml_params, - prompt_factory, - response_schema_prompt, -) +from ..base import BaseLLM class VertexAIError(Exception): diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py similarity index 97% rename from litellm/llms/vertex_httpx.py rename to litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py index 2960dd82f9..3e6751d14e 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py @@ -9,7 +9,7 @@ import types import uuid from enum import Enum from functools import partial -from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import httpx # type: ignore import requests # type: ignore @@ -25,7 +25,9 @@ from litellm.llms.prompt_templates.factory import ( convert_url_to_base64, response_schema_prompt, ) -from litellm.llms.vertex_ai import _gemini_convert_messages_with_history +from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import ( + _gemini_convert_messages_with_history, +) from litellm.types.llms.openai import ( ChatCompletionResponseMessage, ChatCompletionToolCallChunk, @@ -52,7 +54,12 @@ from litellm.types.llms.vertex_ai import ( from litellm.types.utils import GenericStreamingChunk from litellm.utils import CustomStreamWrapper, ModelResponse, Usage -from .base import BaseLLM +from ..base import BaseLLM +from .common_utils import VertexAIError, get_supports_system_message +from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints +from .gemini_transformation import transform_system_message + +context_caching_endpoints = ContextCachingEndpoints() class VertexAIConfig: @@ -789,19 +796,6 @@ def make_sync_call( return completion_stream -class VertexAIError(Exception): - def __init__(self, status_code, message): - self.status_code = status_code - self.message = message - self.request = httpx.Request( - method="POST", url=" https://cloud.google.com/vertex-ai/" - ) - self.response = httpx.Response(status_code=status_code, request=self.request) - super().__init__( - self.message - ) # Call the base class constructor with the parameters it needs - - class VertexLLM(BaseLLM): def __init__(self) -> None: super().__init__() @@ -1366,33 +1360,27 @@ class VertexLLM(BaseLLM): ) ## TRANSFORMATION ## - try: - _custom_llm_provider = 
custom_llm_provider - if custom_llm_provider == "vertex_ai_beta": - _custom_llm_provider = "vertex_ai" - supports_system_message = litellm.supports_system_messages( - model=model, custom_llm_provider=_custom_llm_provider + ### CHECK CONTEXT CACHING ### + if gemini_api_key is not None: + messages, cached_content = context_caching_endpoints.create_cache( + messages=messages, + api_key=gemini_api_key, + api_base=api_base, + model=model, + client=client, + timeout=timeout, + extra_headers=extra_headers, + cached_content=optional_params.pop("cached_content", None), + logging_obj=logging_obj, ) - except Exception as e: - verbose_logger.warning( - "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format( - str(e) - ) - ) - supports_system_message = False - # Separate system prompt from rest of message - system_prompt_indices = [] - system_content_blocks: List[PartType] = [] - if supports_system_message is True: - for idx, message in enumerate(messages): - if message["role"] == "system": - _system_content_block = PartType(text=message["content"]) - system_content_blocks.append(_system_content_block) - system_prompt_indices.append(idx) - if len(system_prompt_indices) > 0: - for idx in reversed(system_prompt_indices): - messages.pop(idx) + # Separate system prompt from rest of message + supports_system_message = get_supports_system_message( + model=model, custom_llm_provider=custom_llm_provider + ) + system_instructions, messages = transform_system_message( + supports_system_message=supports_system_message, messages=messages + ) # Checks for 'response_schema' support - if passed in if "response_schema" in optional_params: supports_response_schema = litellm.supports_response_schema( @@ -1426,13 +1414,11 @@ class VertexLLM(BaseLLM): safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( "safety_settings", None ) # type: ignore - cached_content: Optional[str] = optional_params.pop("cached_content", None) generation_config: Optional[GenerationConfig] = GenerationConfig( **optional_params ) data = RequestBody(contents=content) - if len(system_content_blocks) > 0: - system_instructions = SystemInstructions(parts=system_content_blocks) + if system_instructions is not None: data["system_instruction"] = system_instructions if tools is not None: data["tools"] = tools diff --git a/litellm/main.py b/litellm/main.py index 86ca1ce8a9..5ef8727d5b 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -95,8 +95,6 @@ from .llms import ( replicate, together_ai, triton, - vertex_ai, - vertex_ai_anthropic, vllm, watsonx, ) @@ -124,8 +122,16 @@ from .llms.sagemaker.sagemaker import SagemakerLLM from .llms.text_completion_codestral import CodestralTextCompletion from .llms.text_to_speech.vertex_ai import VertexTextToSpeechAPI from .llms.triton import TritonChatCompletion -from .llms.vertex_ai_partner import VertexAIPartnerModels -from .llms.vertex_httpx import VertexLLM +from .llms.vertex_ai_and_google_ai_studio import ( + vertex_ai_anthropic, + vertex_ai_non_gemini, +) +from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import ( + VertexAIPartnerModels, +) +from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( + VertexLLM, +) from .llms.watsonx import IBMWatsonXAI from .types.llms.openai import HttpxBinaryResponseContent from .types.utils import ( @@ -2112,7 +2118,7 @@ def completion( 
extra_headers=extra_headers, ) else: - model_response = vertex_ai.completion( + model_response = vertex_ai_non_gemini.completion( model=model, messages=messages, model_response=model_response, @@ -3558,7 +3564,7 @@ def embedding( print_verbose=print_verbose, ) else: - response = vertex_ai.embedding( + response = vertex_ai_non_gemini.embedding( model=model, input=input, encoding=encoding, diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 9a8ce48462..b3d14f0a17 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -28,7 +28,9 @@ from litellm import ( completion_cost, embedding, ) -from litellm.llms.vertex_ai import _gemini_convert_messages_with_history +from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( + _gemini_convert_messages_with_history, +) from litellm.tests.test_streaming import streaming_format_tests litellm.num_retries = 3 @@ -2199,3 +2201,137 @@ async def test_completion_fine_tuned_model(): # Optional: Print for debugging print("Arguments passed to Vertex AI:", args_to_vertexai) print("Response:", response) + + +def mock_gemini_request(*args, **kwargs): + print(f"kwargs: {kwargs}") + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + if "cachedContents" in kwargs["url"]: + mock_response.json.return_value = { + "name": "cachedContents/4d2kd477o3pg", + "model": "models/gemini-1.5-flash-001", + "createTime": "2024-08-26T22:31:16.147190Z", + "updateTime": "2024-08-26T22:31:16.147190Z", + "expireTime": "2024-08-26T22:36:15.548934784Z", + "displayName": "", + "usageMetadata": {"totalTokenCount": 323383}, + } + else: + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "Please provide me with the text of the legal agreement" + } + ], + "role": "model", + }, + "finishReason": "MAX_TOKENS", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + }, + ], + } + ], + "usageMetadata": { + "promptTokenCount": 40049, + "candidatesTokenCount": 10, + "totalTokenCount": 40059, + "cachedContentTokenCount": 40012, + }, + } + + return mock_response + + +@pytest.mark.asyncio +async def test_gemini_context_caching_anthropic_format(): + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + litellm.set_verbose = True + client = HTTPHandler(concurrent_limit=1) + with patch.object(client, "post", side_effect=mock_gemini_request) as mock_client: + try: + response = litellm.completion( + model="gemini/gemini-1.5-flash-001", + messages=[ + # System Message + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Here is the full text of a complex legal agreement" + * 4000, + "cache_control": {"type": "ephemeral"}, + } + ], + }, + # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + { + "role": "assistant", + "content": "Certainly! 
the key terms and conditions are the following: the contract is 1 year long for $10/mo", + }, + # The final turn is marked with cache-control, for continuing in followups. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + } + ], + }, + ], + temperature=0.2, + max_tokens=10, + client=client, + ) + + except Exception as e: + print(e) + + assert mock_client.call_count == 2 + + first_call_args = mock_client.call_args_list[0].kwargs + + print(f"first_call_args: {first_call_args}") + + assert "cachedContents" in first_call_args["url"] + + # assert "cache_read_input_tokens" in response.usage + # assert "cache_creation_input_tokens" in response.usage + + # # Assert either a cache entry was created or cache was read - changes depending on the anthropic api ttl + # assert (response.usage.cache_read_input_tokens > 0) or ( + # response.usage.cache_creation_input_tokens > 0 + # ) diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index ce1bd64fa8..470f72c5b6 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -325,11 +325,21 @@ class ChatCompletionDeltaToolCallChunk(TypedDict, total=False): index: int -class ChatCompletionTextObject(TypedDict): +class ChatCompletionCachedContent(TypedDict): + type: Literal["ephemeral"] + + +class OpenAIChatCompletionTextObject(TypedDict): type: Literal["text"] text: str +class ChatCompletionTextObject( + OpenAIChatCompletionTextObject, total=False +): # litellm wrapper on top of openai object for handling cached content + cache_control: ChatCompletionCachedContent + + class ChatCompletionImageUrlObject(TypedDict, total=False): url: Required[str] detail: str diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 74acd4fec4..90730d75fe 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -186,6 +186,17 @@ class RequestBody(TypedDict, total=False): cachedContent: str +class CachedContentRequestBody(TypedDict, total=False): + contents: Required[List[ContentType]] + system_instruction: SystemInstructions + tools: Tools + toolConfig: ToolConfig + model: Required[str] # Format: models/{model} + ttl: str # ending in 's' - Example: "3.5s". + name: str # Format: cachedContents/{id} + displayName: str + + class SafetyRatings(TypedDict): category: HarmCategory probability: HarmProbability @@ -320,3 +331,8 @@ class Instance(TypedDict, total=False): class VertexMultimodalEmbeddingRequest(TypedDict, total=False): instances: List[Instance] + + +class VertexAICachedContentResponseObject(TypedDict): + name: str + model: str diff --git a/litellm/utils.py b/litellm/utils.py index 3187bb9cb1..fe7f101250 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -69,6 +69,7 @@ from litellm.litellm_core_utils.redact_messages import ( from litellm.litellm_core_utils.token_counter import get_modified_max_tokens from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.llms.openai import ( + AllMessageValues, ChatCompletionNamedToolChoiceParam, ChatCompletionToolParam, ) @@ -11549,3 +11550,25 @@ class ModelResponseListIterator: class CustomModelResponseIterator(Iterable): def __init__(self) -> None: super().__init__() + + +def is_cached_message(message: AllMessageValues) -> bool: + """ + Returns true, if message is marked as needing to be cached. + + Used for anthropic/gemini context caching. 
+ + Follows the anthropic format {"cache_control": {"type": "ephemeral"}} + """ + if message["content"] is None or isinstance(message["content"], str): + return False + + for content in message["content"]: + if ( + content["type"] == "text" + and content.get("cache_control") is not None + and content["cache_control"]["type"] == "ephemeral" # type: ignore + ): + return True + + return False
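
Usage sketch (illustrative, not part of the patch): the mocked test test_gemini_context_caching_anthropic_format above exercises the new Gemini context-caching path end to end. A minimal call mirroring that test might look like the snippet below; it assumes a real GEMINI_API_KEY is configured and that gemini/gemini-1.5-flash-001 is available. Messages carrying an Anthropic-style cache_control block are detected by is_cached_message(), split out by separate_cached_messages(), uploaded once to the cachedContents endpoint via ContextCachingEndpoints.create_cache(), and the returned cache name is then attached to the generation request as cached_content.

# Illustrative only -- mirrors the mocked test above, but hits the live API,
# so it needs a valid GEMINI_API_KEY in the environment.
import litellm

response = litellm.completion(
    model="gemini/gemini-1.5-flash-001",
    messages=[
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    # Large repeated prompt, as in the test, so there is
                    # enough content to be worth caching.
                    "text": "Here is the full text of a complex legal agreement" * 4000,
                    # Anthropic-style marker picked up by is_cached_message()
                    # in litellm/utils.py.
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        {
            "role": "user",
            "content": "What are the key terms and conditions in this agreement?",
        },
    ],
    temperature=0.2,
    max_tokens=10,
)
print(response.choices[0].message.content)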