feat(vertex_ai_context_caching.py): check gemini cache, if key already exists

Krrish Dholakia 2024-08-26 20:28:18 -07:00
parent 074e30fa10
commit b277086cf7
6 changed files with 171 additions and 56 deletions
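For orientation, a minimal self-contained sketch of the flow this commit introduces: messages marked with cache_control are keyed deterministically, the Gemini cachedContents endpoint is probed for that key, and a new cached-content entry is only created on a miss. The helper names below are illustrative stand-ins, not names from the diff.

from typing import Dict, List, Optional, Tuple


def _probe_gemini_cache(cache_key: str) -> bool:
    """Stand-in for ContextCachingEndpoints.check_cache: GET cachedContents/{cache_key}."""
    return False  # pretend nothing is cached yet


def _create_gemini_cache(cached_messages: List[Dict], cache_key: str) -> None:
    """Stand-in for the create path: POST cachedContents with name=cachedContents/{cache_key}."""


def check_and_create_cache_flow(
    cached_messages: List[Dict],
    non_cached_messages: List[Dict],
    cache_key: str,
) -> Tuple[List[Dict], Optional[str]]:
    if not cached_messages:
        return non_cached_messages, None  # nothing marked for caching
    if _probe_gemini_cache(cache_key):  # cache hit: reuse the existing entry
        return non_cached_messages, cache_key
    _create_gemini_cache(cached_messages, cache_key)  # cache miss: create it once
    return non_cached_messages, cache_key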

View file

@@ -2236,7 +2236,7 @@ class Cache:
         if self.namespace is not None and isinstance(self.cache, RedisCache):
             self.cache.namespace = self.namespace
 
-    def get_cache_key(self, *args, **kwargs):
+    def get_cache_key(self, *args, **kwargs) -> str:
         """
         Get the cache key for the given arguments.
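The `-> str` annotation matters here because the key returned by get_cache_key is now embedded verbatim in the Gemini cachedContents resource name. A hedged usage sketch (the message payload is made up for illustration):

from litellm.caching import Cache

# A local Cache instance is only used to derive a deterministic key; nothing is stored in it.
local_cache_obj = Cache(type="local")

key: str = local_cache_obj.get_cache_key(
    messages=[{"role": "user", "content": "Here is the full text of a complex legal agreement"}]
)
print(key)  # identical message payloads yield the same key, later reused as cachedContents/{key}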

View file

@@ -67,8 +67,7 @@ def separate_cached_messages(
 def transform_openai_messages_to_gemini_context_caching(
-    model: str,
-    messages: List[AllMessageValues],
+    model: str, messages: List[AllMessageValues], cache_key: str
 ) -> CachedContentRequestBody:
     supports_system_message = get_supports_system_message(
         model=model, custom_llm_provider="gemini"
@@ -80,7 +79,9 @@ def transform_openai_messages_to_gemini_context_caching(
     transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
     data = CachedContentRequestBody(
-        contents=transformed_messages, model="models/{}".format(model)
+        contents=transformed_messages,
+        model="models/{}".format(model),
+        name="cachedContents/{}".format(cache_key),
     )
     if transformed_system_messages is not None:
         data["system_instruction"] = transformed_system_messages

View file

@@ -4,6 +4,7 @@ from typing import Callable, List, Literal, Optional, Tuple, Union
 import httpx
 
 import litellm
+from litellm.caching import Cache
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.openai import AllMessageValues
@@ -19,6 +20,8 @@ from .transformation import (
     transform_openai_messages_to_gemini_context_caching,
 )
 
+local_cache_obj = Cache(type="local")  # only used for calling 'get_cache_key' function
+
 
 class ContextCachingEndpoints:
     """
@@ -32,10 +35,10 @@ class ContextCachingEndpoints:
     def _get_token_and_url(
         self,
-        model: str,
         gemini_api_key: Optional[str],
         custom_llm_provider: Literal["gemini"],
         api_base: Optional[str],
+        cached_key: Optional[str],
     ) -> Tuple[Optional[str], str]:
         """
         Internal function. Returns the token and url for the call.
@@ -46,12 +49,18 @@ class ContextCachingEndpoints:
         token, url
         """
         if custom_llm_provider == "gemini":
-            _gemini_model_name = "models/{}".format(model)
             auth_header = None
             endpoint = "cachedContents"
-            url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
-                endpoint, gemini_api_key
-            )
+            if cached_key is not None:
+                url = "https://generativelanguage.googleapis.com/v1beta/{}/{}?key={}".format(
+                    endpoint, cached_key, gemini_api_key
+                )
+            else:
+                url = (
+                    "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
+                        endpoint, gemini_api_key
+                    )
+                )
         else:
             raise NotImplementedError
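The branch above yields two URL shapes. A small illustration with made-up key values:

endpoint = "cachedContents"
gemini_api_key = "AIza-example"   # illustrative, not a real key
cached_key = "0a1b2c3d"           # illustrative cache key

# cached_key provided -> address a specific cached-content resource (used by check_cache)
check_url = "https://generativelanguage.googleapis.com/v1beta/{}/{}?key={}".format(
    endpoint, cached_key, gemini_api_key
)
# no cached_key -> the collection endpoint used to create a new cached-content entry
create_url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
    endpoint, gemini_api_key
)
print(check_url)   # https://generativelanguage.googleapis.com/v1beta/cachedContents/0a1b2c3d?key=AIza-example
print(create_url)  # https://generativelanguage.googleapis.com/v1beta/cachedContents?key=AIza-example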
@@ -68,7 +77,48 @@ class ContextCachingEndpoints:
 
         return auth_header, url
 
-    def create_cache(
+    def check_cache(
+        self,
+        cache_key: str,
+        client: HTTPHandler,
+        headers: dict,
+        api_key: str,
+        api_base: Optional[str],
+        logging_obj: Logging,
+    ) -> bool:
+        """Checks if content already cached."""
+
+        _, url = self._get_token_and_url(
+            gemini_api_key=api_key,
+            custom_llm_provider="gemini",
+            api_base=api_base,
+            cached_key=cache_key,
+        )
+        try:
+            ## LOGGING
+            logging_obj.pre_call(
+                input="",
+                api_key="",
+                additional_args={
+                    "complete_input_dict": {},
+                    "api_base": url,
+                    "headers": headers,
+                },
+            )
+
+            resp = client.get(url=url, headers=headers)
+            resp.raise_for_status()
+            return True
+        except httpx.HTTPStatusError as e:
+            if e.response.status_code == 403:
+                return False
+            raise VertexAIError(
+                status_code=e.response.status_code, message=e.response.text
+            )
+        except Exception as e:
+            raise VertexAIError(status_code=500, message=str(e))
+
+    def check_and_create_cache(
         self,
         messages: List[AllMessageValues],  # receives openai format messages
         api_key: str,
@@ -95,10 +145,10 @@ class ContextCachingEndpoints:
         ## AUTHORIZATION ##
         token, url = self._get_token_and_url(
-            model=model,
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,
+            cached_key=None,
         )
 
         headers = {
@@ -126,9 +176,23 @@ class ContextCachingEndpoints:
         if len(cached_messages) == 0:
             return messages, None
 
+        ## CHECK IF CACHED ALREADY
+        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
+        cache_exists = self.check_cache(
+            cache_key=generated_cache_key,
+            client=client,
+            headers=headers,
+            api_key=api_key,
+            api_base=api_base,
+            logging_obj=logging_obj,
+        )
+        if cache_exists:
+            return non_cached_messages, generated_cache_key
+
+        ## TRANSFORM REQUEST
         cached_content_request_body = (
             transform_openai_messages_to_gemini_context_caching(
-                model=model, messages=cached_messages
+                model=model, messages=cached_messages, cache_key=generated_cache_key
             )
         )

View file

@@ -1362,7 +1362,7 @@ class VertexLLM(BaseLLM):
         ## TRANSFORMATION ##
         ### CHECK CONTEXT CACHING ###
         if gemini_api_key is not None:
-            messages, cached_content = context_caching_endpoints.create_cache(
+            messages, cached_content = context_caching_endpoints.check_and_create_cache(
                 messages=messages,
                 api_key=gemini_api_key,
                 api_base=api_base,

View file

@@ -2263,6 +2263,46 @@ def mock_gemini_request(*args, **kwargs):
     return mock_response
 
 
+gemini_context_caching_messages = [
+    # System Message
+    {
+        "role": "system",
+        "content": [
+            {
+                "type": "text",
+                "text": "Here is the full text of a complex legal agreement" * 4000,
+                "cache_control": {"type": "ephemeral"},
+            }
+        ],
+    },
+    # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "text",
+                "text": "What are the key terms and conditions in this agreement?",
+                "cache_control": {"type": "ephemeral"},
+            }
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+    },
+    # The final turn is marked with cache-control, for continuing in followups.
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "text",
+                "text": "What are the key terms and conditions in this agreement?",
+            }
+        ],
+    },
+]
+
+
 @pytest.mark.asyncio
 async def test_gemini_context_caching_anthropic_format():
     from litellm.llms.custom_httpx.http_handler import HTTPHandler
@@ -2273,45 +2313,7 @@ async def test_gemini_context_caching_anthropic_format():
     try:
         response = litellm.completion(
             model="gemini/gemini-1.5-flash-001",
-            messages=[
-                # System Message
-                {
-                    "role": "system",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "Here is the full text of a complex legal agreement"
-                            * 4000,
-                            "cache_control": {"type": "ephemeral"},
-                        }
-                    ],
-                },
-                # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "What are the key terms and conditions in this agreement?",
-                            "cache_control": {"type": "ephemeral"},
-                        }
-                    ],
-                },
-                {
-                    "role": "assistant",
-                    "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
-                },
-                # The final turn is marked with cache-control, for continuing in followups.
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "What are the key terms and conditions in this agreement?",
-                        }
-                    ],
-                },
-            ],
+            messages=gemini_context_caching_messages,
             temperature=0.2,
             max_tokens=10,
             client=client,
@@ -2335,3 +2337,20 @@ async def test_gemini_context_caching_anthropic_format():
     # assert (response.usage.cache_read_input_tokens > 0) or (
     #     response.usage.cache_creation_input_tokens > 0
     # )
+
+    check_cache_mock = MagicMock()
+    client.get = check_cache_mock
+
+    try:
+        response = litellm.completion(
+            model="gemini/gemini-1.5-flash-001",
+            messages=gemini_context_caching_messages,
+            temperature=0.2,
+            max_tokens=10,
+            client=client,
+        )
+    except Exception as e:
+        print(e)
+
+    check_cache_mock.assert_called_once()
+    assert mock_client.call_count == 3

View file

@@ -5070,6 +5070,10 @@ def get_max_tokens(model: str) -> Optional[int]:
     )
 
 
+def _strip_stable_vertex_version(model_name) -> str:
+    return re.sub(r"-\d+$", "", model_name)
+
+
 def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     """
     Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model.
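What _strip_stable_vertex_version does, checked on a few inputs (the function body is copied from the hunk above so the snippet runs standalone):

import re


def _strip_stable_vertex_version(model_name) -> str:
    # drops a trailing "-<digits>" stable-version suffix such as "-001"
    return re.sub(r"-\d+$", "", model_name)


assert _strip_stable_vertex_version("gemini-1.5-flash-001") == "gemini-1.5-flash"
assert _strip_stable_vertex_version("gemini-1.5-pro-002") == "gemini-1.5-pro"
assert _strip_stable_vertex_version("gemini-1.5-flash") == "gemini-1.5-flash"  # no suffix, unchanged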
@@ -5171,9 +5175,15 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             except:
                 pass
             combined_model_name = model
+            combined_stripped_model_name = _strip_stable_vertex_version(
+                model_name=model
+            )
         else:
             split_model = model
             combined_model_name = "{}/{}".format(custom_llm_provider, model)
+            combined_stripped_model_name = "{}/{}".format(
+                custom_llm_provider, _strip_stable_vertex_version(model_name=model)
+            )
         #########################
 
         supported_openai_params = litellm.get_supported_openai_params(
@@ -5200,8 +5210,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             """
             Check if: (in order of specificity)
             1. 'custom_llm_provider/model' in litellm.model_cost. Checks "groq/llama3-8b-8192" if model="llama3-8b-8192" and custom_llm_provider="groq"
-            2. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" and custom_llm_provider=None
-            3. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
+            2. 'combined_stripped_model_name' in litellm.model_cost. Checks if 'gemini/gemini-1.5-flash' in model map, if 'gemini/gemini-1.5-flash-001' given.
+            3. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" and custom_llm_provider=None
+            4. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
             """
             if combined_model_name in litellm.model_cost:
                 key = combined_model_name
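A hedged walk-through of the new lookup order for the model this commit targets, assuming the cost map carries the stable alias but not the pinned -001 name (illustrative data, not the real litellm.model_cost):

# Illustrative stand-in for litellm.model_cost
model_cost = {"gemini/gemini-1.5-flash": {"litellm_provider": "gemini"}}

combined_model_name = "gemini/gemini-1.5-flash-001"       # step 1: exact provider/model
combined_stripped_model_name = "gemini/gemini-1.5-flash"  # step 2: version suffix stripped

key = None
for candidate in (combined_model_name, combined_stripped_model_name):
    if candidate in model_cost:
        key = candidate
        break

assert key == "gemini/gemini-1.5-flash"  # the -001 model resolves to the stable entry's pricing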
@@ -5217,6 +5228,26 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                         pass
                     else:
                         raise Exception
+            elif combined_stripped_model_name in litellm.model_cost:
+                key = model
+                _model_info = litellm.model_cost[combined_stripped_model_name]
+                _model_info["supported_openai_params"] = supported_openai_params
+                if (
+                    "litellm_provider" in _model_info
+                    and _model_info["litellm_provider"] != custom_llm_provider
+                ):
+                    if custom_llm_provider == "vertex_ai" and _model_info[
+                        "litellm_provider"
+                    ].startswith("vertex_ai"):
+                        pass
+                    else:
+                        raise Exception(
+                            "Got provider={}, Expected provider={}, for model={}".format(
+                                _model_info["litellm_provider"],
+                                custom_llm_provider,
+                                model,
+                            )
+                        )
             elif model in litellm.model_cost:
                 key = model
                 _model_info = litellm.model_cost[model]
@@ -5320,9 +5351,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     "supports_assistant_prefill", False
                 ),
             )
-    except Exception:
+    except Exception as e:
         raise Exception(
-            "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+            "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.".format(
                 model, custom_llm_provider
             )
         )