diff --git a/litellm/caching.py b/litellm/caching.py
index 880018c4c2..1ef9757846 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -2236,7 +2236,7 @@ class Cache:
         if self.namespace is not None and isinstance(self.cache, RedisCache):
             self.cache.namespace = self.namespace
 
-    def get_cache_key(self, *args, **kwargs):
+    def get_cache_key(self, *args, **kwargs) -> str:
         """
         Get the cache key for the given arguments.
 
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py
index 37c89838c5..944ae00bc9 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py
@@ -67,8 +67,7 @@ def separate_cached_messages(
 
 def transform_openai_messages_to_gemini_context_caching(
-    model: str,
-    messages: List[AllMessageValues],
+    model: str, messages: List[AllMessageValues], cache_key: str
 ) -> CachedContentRequestBody:
     supports_system_message = get_supports_system_message(
         model=model, custom_llm_provider="gemini"
     )
@@ -80,7 +79,9 @@ def transform_openai_messages_to_gemini_context_caching(
 
     transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
     data = CachedContentRequestBody(
-        contents=transformed_messages, model="models/{}".format(model)
+        contents=transformed_messages,
+        model="models/{}".format(model),
+        name="cachedContents/{}".format(cache_key),
     )
     if transformed_system_messages is not None:
         data["system_instruction"] = transformed_system_messages
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py
index 6461fe04b8..991883e409 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py
@@ -4,6 +4,7 @@ from typing import Callable, List, Literal, Optional, Tuple, Union
 import httpx
 
 import litellm
+from litellm.caching import Cache
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.openai import AllMessageValues
@@ -19,6 +20,8 @@ from .transformation import (
     transform_openai_messages_to_gemini_context_caching,
 )
 
+local_cache_obj = Cache(type="local")  # only used for calling 'get_cache_key' function
+
 
 class ContextCachingEndpoints:
     """
@@ -32,10 +35,10 @@ class ContextCachingEndpoints:
 
     def _get_token_and_url(
         self,
-        model: str,
         gemini_api_key: Optional[str],
         custom_llm_provider: Literal["gemini"],
         api_base: Optional[str],
+        cached_key: Optional[str],
     ) -> Tuple[Optional[str], str]:
         """
         Internal function. Returns the token and url for the call.
@@ -46,12 +49,18 @@ class ContextCachingEndpoints:
             token, url
         """
         if custom_llm_provider == "gemini":
-            _gemini_model_name = "models/{}".format(model)
             auth_header = None
             endpoint = "cachedContents"
-            url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
-                endpoint, gemini_api_key
-            )
+            if cached_key is not None:
+                url = "https://generativelanguage.googleapis.com/v1beta/{}/{}?key={}".format(
+                    endpoint, cached_key, gemini_api_key
+                )
+            else:
+                url = (
+                    "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
+                        endpoint, gemini_api_key
+                    )
+                )
         else:
             raise NotImplementedError
 
@@ -68,7 +77,48 @@ class ContextCachingEndpoints:
 
         return auth_header, url
 
-    def create_cache(
+    def check_cache(
+        self,
+        cache_key: str,
+        client: HTTPHandler,
+        headers: dict,
+        api_key: str,
+        api_base: Optional[str],
+        logging_obj: Logging,
+    ) -> bool:
+        """Checks if content already cached."""
+
+        _, url = self._get_token_and_url(
+            gemini_api_key=api_key,
+            custom_llm_provider="gemini",
+            api_base=api_base,
+            cached_key=cache_key,
+        )
+        try:
+            ## LOGGING
+            logging_obj.pre_call(
+                input="",
+                api_key="",
+                additional_args={
+                    "complete_input_dict": {},
+                    "api_base": url,
+                    "headers": headers,
+                },
+            )
+
+            resp = client.get(url=url, headers=headers)
+            resp.raise_for_status()
+            return True
+        except httpx.HTTPStatusError as e:
+            if e.response.status_code == 403:
+                return False
+            raise VertexAIError(
+                status_code=e.response.status_code, message=e.response.text
+            )
+        except Exception as e:
+            raise VertexAIError(status_code=500, message=str(e))
+
+    def check_and_create_cache(
         self,
         messages: List[AllMessageValues],  # receives openai format messages
         api_key: str,
@@ -95,10 +145,10 @@ class ContextCachingEndpoints:
 
         ## AUTHORIZATION ##
         token, url = self._get_token_and_url(
-            model=model,
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,
+            cached_key=None,
         )
 
         headers = {
@@ -126,9 +176,23 @@ class ContextCachingEndpoints:
         if len(cached_messages) == 0:
             return messages, None
 
+        ## CHECK IF CACHED ALREADY
+        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
+        cache_exists = self.check_cache(
+            cache_key=generated_cache_key,
+            client=client,
+            headers=headers,
+            api_key=api_key,
+            api_base=api_base,
+            logging_obj=logging_obj,
+        )
+        if cache_exists:
+            return non_cached_messages, generated_cache_key
+
+        ## TRANSFORM REQUEST
         cached_content_request_body = (
             transform_openai_messages_to_gemini_context_caching(
-                model=model, messages=cached_messages
+                model=model, messages=cached_messages, cache_key=generated_cache_key
             )
         )
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py
index 3e6751d14e..b5e32715bf 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py
@@ -1362,7 +1362,7 @@ class VertexLLM(BaseLLM):
         ## TRANSFORMATION ##
         ### CHECK CONTEXT CACHING ###
         if gemini_api_key is not None:
-            messages, cached_content = context_caching_endpoints.create_cache(
+            messages, cached_content = context_caching_endpoints.check_and_create_cache(
                 messages=messages,
                 api_key=gemini_api_key,
                 api_base=api_base,
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index b3d14f0a17..33f4df01b1 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -2263,6 +2263,46 @@ def mock_gemini_request(*args, **kwargs):
     return mock_response
 
 
+gemini_context_caching_messages = [
+    # System Message
+    {
+        "role": "system",
+        "content": [
+            {
+                "type": "text",
+                "text": "Here is the full text of a complex legal agreement" * 4000,
+                "cache_control": {"type": "ephemeral"},
+            }
+        ],
+    },
+    # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "text",
+                "text": "What are the key terms and conditions in this agreement?",
+                "cache_control": {"type": "ephemeral"},
+            }
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+    },
+    # The final turn is marked with cache-control, for continuing in followups.
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "text",
+                "text": "What are the key terms and conditions in this agreement?",
+            }
+        ],
+    },
+]
+
+
 @pytest.mark.asyncio
 async def test_gemini_context_caching_anthropic_format():
     from litellm.llms.custom_httpx.http_handler import HTTPHandler
@@ -2273,45 +2313,7 @@ async def test_gemini_context_caching_anthropic_format():
     try:
         response = litellm.completion(
             model="gemini/gemini-1.5-flash-001",
-            messages=[
-                # System Message
-                {
-                    "role": "system",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "Here is the full text of a complex legal agreement"
-                            * 4000,
-                            "cache_control": {"type": "ephemeral"},
-                        }
-                    ],
-                },
-                # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "What are the key terms and conditions in this agreement?",
-                            "cache_control": {"type": "ephemeral"},
-                        }
-                    ],
-                },
-                {
-                    "role": "assistant",
-                    "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
-                },
-                # The final turn is marked with cache-control, for continuing in followups.
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "What are the key terms and conditions in this agreement?",
-                        }
-                    ],
-                },
-            ],
+            messages=gemini_context_caching_messages,
             temperature=0.2,
             max_tokens=10,
             client=client,
@@ -2335,3 +2337,20 @@ async def test_gemini_context_caching_anthropic_format():
     # assert (response.usage.cache_read_input_tokens > 0) or (
     #     response.usage.cache_creation_input_tokens > 0
     # )
+
+    check_cache_mock = MagicMock()
+    client.get = check_cache_mock
+    try:
+        response = litellm.completion(
+            model="gemini/gemini-1.5-flash-001",
+            messages=gemini_context_caching_messages,
+            temperature=0.2,
+            max_tokens=10,
+            client=client,
+        )
+
+    except Exception as e:
+        print(e)
+
+    check_cache_mock.assert_called_once()
+    assert mock_client.call_count == 3
diff --git a/litellm/utils.py b/litellm/utils.py
index fe7f101250..92b5d5db06 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5070,6 +5070,10 @@ def get_max_tokens(model: str) -> Optional[int]:
     )
 
 
+def _strip_stable_vertex_version(model_name) -> str:
+    return re.sub(r"-\d+$", "", model_name)
+
+
 def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     """
     Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model.
@@ -5171,9 +5175,15 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             except:
                 pass
             combined_model_name = model
+            combined_stripped_model_name = _strip_stable_vertex_version(
+                model_name=model
+            )
         else:
             split_model = model
             combined_model_name = "{}/{}".format(custom_llm_provider, model)
+            combined_stripped_model_name = "{}/{}".format(
+                custom_llm_provider, _strip_stable_vertex_version(model_name=model)
+            )
         #########################
 
         supported_openai_params = litellm.get_supported_openai_params(
@@ -5200,8 +5210,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
         """
        Check if: (in order of specificity)
         1. 'custom_llm_provider/model' in litellm.model_cost. Checks "groq/llama3-8b-8192" if model="llama3-8b-8192" and custom_llm_provider="groq"
-        2. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" and custom_llm_provider=None
-        3. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
+        2. 'combined_stripped_model_name' in litellm.model_cost. Checks if 'gemini/gemini-1.5-flash' in model map, if 'gemini/gemini-1.5-flash-001' given.
+        3. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" and custom_llm_provider=None
+        4. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
         """
         if combined_model_name in litellm.model_cost:
             key = combined_model_name
@@ -5217,6 +5228,26 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     pass
                 else:
                     raise Exception
+        elif combined_stripped_model_name in litellm.model_cost:
+            key = model
+            _model_info = litellm.model_cost[combined_stripped_model_name]
+            _model_info["supported_openai_params"] = supported_openai_params
+            if (
+                "litellm_provider" in _model_info
+                and _model_info["litellm_provider"] != custom_llm_provider
+            ):
+                if custom_llm_provider == "vertex_ai" and _model_info[
+                    "litellm_provider"
+                ].startswith("vertex_ai"):
+                    pass
+                else:
+                    raise Exception(
+                        "Got provider={}, Expected provider={}, for model={}".format(
+                            _model_info["litellm_provider"],
+                            custom_llm_provider,
+                            model,
+                        )
+                    )
         elif model in litellm.model_cost:
             key = model
             _model_info = litellm.model_cost[model]
@@ -5320,9 +5351,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 "supports_assistant_prefill", False
             ),
         )
-    except Exception:
+    except Exception as e:
         raise Exception(
-            "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+            "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.".format(
                 model, custom_llm_provider
             )
         )
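
As a quick illustration of the new `_strip_stable_vertex_version` helper added to litellm/utils.py above: it drops a trailing "-<digits>" stable-version suffix so a versioned Gemini model name such as "gemini-1.5-flash-001" can fall back to the base "gemini-1.5-flash" entry when `combined_stripped_model_name` is looked up in `litellm.model_cost`. This is a minimal standalone sketch, not part of the patch; the asserts are illustrative examples only.

import re


def _strip_stable_vertex_version(model_name) -> str:
    # Same regex as the helper in the diff above: strip a trailing
    # "-<digits>" stable-version suffix from the model name.
    return re.sub(r"-\d+$", "", model_name)


# "gemini-1.5-flash-001" falls back to the base "gemini-1.5-flash" entry.
assert _strip_stable_vertex_version("gemini-1.5-flash-001") == "gemini-1.5-flash"
# Names without a trailing numeric suffix are returned unchanged.
assert _strip_stable_vertex_version("gemini-1.5-pro") == "gemini-1.5-pro"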