Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
feat(vertex_ai_context_caching.py): check gemini cache, if key already exists
Commit 0eea01dae9 (parent b0cc1df2d6)
6 changed files with 171 additions and 56 deletions
@@ -2236,7 +2236,7 @@ class Cache:
         if self.namespace is not None and isinstance(self.cache, RedisCache):
             self.cache.namespace = self.namespace

-    def get_cache_key(self, *args, **kwargs):
+    def get_cache_key(self, *args, **kwargs) -> str:
        """
        Get the cache key for the given arguments.

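The stricter return type is what the rest of this commit leans on: the key is formatted straight into the Gemini cache resource name. A minimal sketch of that usage, reusing the Cache(type="local") helper introduced later in this diff; the message list is illustrative, not taken from the repo.

    from litellm.caching import Cache

    local_cache_obj = Cache(type="local")  # only used for deriving deterministic cache keys

    # Illustrative cached-message block (any OpenAI-format message list works here).
    cached_messages = [
        {"role": "system", "content": "Here is the full text of a complex legal agreement ..."}
    ]

    # get_cache_key() now advertises a str return, so the key can be embedded
    # directly in the Gemini resource name without extra checks.
    key: str = local_cache_obj.get_cache_key(messages=cached_messages)
    print("cachedContents/{}".format(key))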
@@ -67,8 +67,7 @@ def separate_cached_messages(


 def transform_openai_messages_to_gemini_context_caching(
-    model: str,
-    messages: List[AllMessageValues],
+    model: str, messages: List[AllMessageValues], cache_key: str
 ) -> CachedContentRequestBody:
     supports_system_message = get_supports_system_message(
         model=model, custom_llm_provider="gemini"
@@ -80,7 +79,9 @@ def transform_openai_messages_to_gemini_context_caching(

     transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
     data = CachedContentRequestBody(
-        contents=transformed_messages, model="models/{}".format(model)
+        contents=transformed_messages,
+        model="models/{}".format(model),
+        name="cachedContents/{}".format(cache_key),
     )
     if transformed_system_messages is not None:
         data["system_instruction"] = transformed_system_messages
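With the extra cache_key parameter, the transformed request body pins the cache to a deterministic resource name, so the same cacheable block always maps to the same cachedContents entry. A minimal sketch of a call, reusing local_cache_obj and cached_messages from the sketch above; the model name and the import path are assumptions for illustration (the diff itself only shows a relative "from .transformation import ..." in the caller).

    # Assumed import path, for illustration only.
    from litellm.llms.vertex_ai_and_google_ai_studio.context_caching.transformation import (
        transform_openai_messages_to_gemini_context_caching,
    )

    cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
    request_body = transform_openai_messages_to_gemini_context_caching(
        model="gemini-1.5-flash-001",
        messages=cached_messages,
        cache_key=cache_key,
    )

    # Per the hunk above, the body now carries:
    #   request_body["model"] == "models/gemini-1.5-flash-001"
    #   request_body["name"]  == "cachedContents/{}".format(cache_key)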
@@ -4,6 +4,7 @@ from typing import Callable, List, Literal, Optional, Tuple, Union
 import httpx

 import litellm
+from litellm.caching import Cache
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.openai import AllMessageValues
@@ -19,6 +20,8 @@ from .transformation import (
     transform_openai_messages_to_gemini_context_caching,
 )

+local_cache_obj = Cache(type="local")  # only used for calling 'get_cache_key' function
+

 class ContextCachingEndpoints:
     """
@@ -32,10 +35,10 @@ class ContextCachingEndpoints:

     def _get_token_and_url(
         self,
-        model: str,
         gemini_api_key: Optional[str],
         custom_llm_provider: Literal["gemini"],
         api_base: Optional[str],
+        cached_key: Optional[str],
     ) -> Tuple[Optional[str], str]:
         """
         Internal function. Returns the token and url for the call.
@@ -46,12 +49,18 @@ class ContextCachingEndpoints:
        token, url
        """
        if custom_llm_provider == "gemini":
-            _gemini_model_name = "models/{}".format(model)
            auth_header = None
            endpoint = "cachedContents"
-            url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
-                endpoint, gemini_api_key
-            )
+            if cached_key is not None:
+                url = "https://generativelanguage.googleapis.com/v1beta/{}/{}?key={}".format(
+                    endpoint, cached_key, gemini_api_key
+                )
+            else:
+                url = (
+                    "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
+                        endpoint, gemini_api_key
+                    )
+                )

        else:
            raise NotImplementedError
@@ -68,7 +77,48 @@ class ContextCachingEndpoints:

         return auth_header, url

-    def create_cache(
+    def check_cache(
+        self,
+        cache_key: str,
+        client: HTTPHandler,
+        headers: dict,
+        api_key: str,
+        api_base: Optional[str],
+        logging_obj: Logging,
+    ) -> bool:
+        """Checks if content already cached."""
+
+        _, url = self._get_token_and_url(
+            gemini_api_key=api_key,
+            custom_llm_provider="gemini",
+            api_base=api_base,
+            cached_key=cache_key,
+        )
+        try:
+            ## LOGGING
+            logging_obj.pre_call(
+                input="",
+                api_key="",
+                additional_args={
+                    "complete_input_dict": {},
+                    "api_base": url,
+                    "headers": headers,
+                },
+            )
+
+            resp = client.get(url=url, headers=headers)
+            resp.raise_for_status()
+            return True
+        except httpx.HTTPStatusError as e:
+            if e.response.status_code == 403:
+                return False
+            raise VertexAIError(
+                status_code=e.response.status_code, message=e.response.text
+            )
+        except Exception as e:
+            raise VertexAIError(status_code=500, message=str(e))
+
+    def check_and_create_cache(
         self,
         messages: List[AllMessageValues],  # receives openai format messages
         api_key: str,
@@ -95,10 +145,10 @@ class ContextCachingEndpoints:

         ## AUTHORIZATION ##
         token, url = self._get_token_and_url(
-            model=model,
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,
+            cached_key=None,
         )

         headers = {
@@ -126,9 +176,23 @@ class ContextCachingEndpoints:
         if len(cached_messages) == 0:
             return messages, None

+        ## CHECK IF CACHED ALREADY
+        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
+        cache_exists = self.check_cache(
+            cache_key=generated_cache_key,
+            client=client,
+            headers=headers,
+            api_key=api_key,
+            api_base=api_base,
+            logging_obj=logging_obj,
+        )
+        if cache_exists:
+            return non_cached_messages, generated_cache_key
+
+        ## TRANSFORM REQUEST
         cached_content_request_body = (
             transform_openai_messages_to_gemini_context_caching(
-                model=model, messages=cached_messages
+                model=model, messages=cached_messages, cache_key=generated_cache_key
             )
         )
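Taken together, the new flow is: derive a deterministic key from the cacheable block, GET the matching cachedContents resource, treat a 403 as "not cached yet", and only then create the cache. A standalone sketch of that behaviour using httpx directly; the GEMINI_API_KEY environment variable and the ensure_gemini_cache helper name are assumptions for illustration, while the two URL shapes and the 403 handling come from the hunks above.

    import os

    import httpx

    GEMINI_BASE = "https://generativelanguage.googleapis.com/v1beta"
    api_key = os.environ["GEMINI_API_KEY"]  # assumed: key supplied via environment


    def ensure_gemini_cache(cache_key: str, request_body: dict) -> str:
        """Return the cachedContents name, creating the cache only when the lookup 403s."""
        lookup_url = "{}/cachedContents/{}?key={}".format(GEMINI_BASE, cache_key, api_key)
        create_url = "{}/cachedContents?key={}".format(GEMINI_BASE, api_key)

        try:
            resp = httpx.get(lookup_url)
            resp.raise_for_status()
            return "cachedContents/{}".format(cache_key)  # cache already exists
        except httpx.HTTPStatusError as e:
            if e.response.status_code != 403:
                raise  # only 403 means "no cache yet", mirroring check_cache()

        created = httpx.post(create_url, json=request_body)
        created.raise_for_status()
        return "cachedContents/{}".format(cache_key)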
@@ -1362,7 +1362,7 @@ class VertexLLM(BaseLLM):
         ## TRANSFORMATION ##
         ### CHECK CONTEXT CACHING ###
         if gemini_api_key is not None:
-            messages, cached_content = context_caching_endpoints.create_cache(
+            messages, cached_content = context_caching_endpoints.check_and_create_cache(
                 messages=messages,
                 api_key=gemini_api_key,
                 api_base=api_base,
@@ -2263,25 +2263,14 @@ def mock_gemini_request(*args, **kwargs):
     return mock_response


-@pytest.mark.asyncio
-async def test_gemini_context_caching_anthropic_format():
-    from litellm.llms.custom_httpx.http_handler import HTTPHandler
-
-    litellm.set_verbose = True
-    client = HTTPHandler(concurrent_limit=1)
-    with patch.object(client, "post", side_effect=mock_gemini_request) as mock_client:
-        try:
-            response = litellm.completion(
-                model="gemini/gemini-1.5-flash-001",
-                messages=[
+gemini_context_caching_messages = [
     # System Message
     {
         "role": "system",
         "content": [
             {
                 "type": "text",
-                "text": "Here is the full text of a complex legal agreement"
-                * 4000,
+                "text": "Here is the full text of a complex legal agreement" * 4000,
                 "cache_control": {"type": "ephemeral"},
             }
         ],
@@ -2311,7 +2300,20 @@ async def test_gemini_context_caching_anthropic_format():
             }
         ],
     },
-        ],
+]
+
+
+@pytest.mark.asyncio
+async def test_gemini_context_caching_anthropic_format():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    litellm.set_verbose = True
+    client = HTTPHandler(concurrent_limit=1)
+    with patch.object(client, "post", side_effect=mock_gemini_request) as mock_client:
+        try:
+            response = litellm.completion(
+                model="gemini/gemini-1.5-flash-001",
+                messages=gemini_context_caching_messages,
                 temperature=0.2,
                 max_tokens=10,
                 client=client,
@@ -2335,3 +2337,20 @@ async def test_gemini_context_caching_anthropic_format():
         # assert (response.usage.cache_read_input_tokens > 0) or (
         #     response.usage.cache_creation_input_tokens > 0
         # )
+
+        check_cache_mock = MagicMock()
+        client.get = check_cache_mock
+        try:
+            response = litellm.completion(
+                model="gemini/gemini-1.5-flash-001",
+                messages=gemini_context_caching_messages,
+                temperature=0.2,
+                max_tokens=10,
+                client=client,
+            )
+
+        except Exception as e:
+            print(e)
+
+        check_cache_mock.assert_called_once()
+        assert mock_client.call_count == 3
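Why a bare MagicMock is enough to push the second completion down the "cache already exists" path (a reading of the new check_cache(), not text from the repo): check_cache() only calls raise_for_status() on whatever client.get() returns, and a MagicMock turns that into a no-op, so the method returns True and the cache-creation POST is skipped. A tiny sketch of that mechanic; the URL string is a placeholder.

    from unittest.mock import MagicMock

    check_cache_mock = MagicMock()

    # Whatever check_cache() does with the stubbed response succeeds silently:
    resp = check_cache_mock(url="https://.../cachedContents/<key>?key=...", headers={})
    resp.raise_for_status()  # MagicMock no-op, nothing raised -> treated as "cache exists"

    # So on the second run only the generateContent POST reaches mock_client,
    # which is why the test expects check_cache_mock.assert_called_once() and
    # mock_client.call_count == 3 across both runs.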
@@ -5070,6 +5070,10 @@ def get_max_tokens(model: str) -> Optional[int]:
     )


+def _strip_stable_vertex_version(model_name) -> str:
+    return re.sub(r"-\d+$", "", model_name)
+
+
 def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     """
     Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model.
@@ -5171,9 +5175,15 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
         except:
             pass
         combined_model_name = model
+        combined_stripped_model_name = _strip_stable_vertex_version(
+            model_name=model
+        )
     else:
         split_model = model
         combined_model_name = "{}/{}".format(custom_llm_provider, model)
+        combined_stripped_model_name = "{}/{}".format(
+            custom_llm_provider, _strip_stable_vertex_version(model_name=model)
+        )
     #########################

     supported_openai_params = litellm.get_supported_openai_params(
@@ -5200,8 +5210,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     """
     Check if: (in order of specificity)
     1. 'custom_llm_provider/model' in litellm.model_cost. Checks "groq/llama3-8b-8192" if model="llama3-8b-8192" and custom_llm_provider="groq"
-    2. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" and custom_llm_provider=None
-    3. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
+    2. 'combined_stripped_model_name' in litellm.model_cost. Checks if 'gemini/gemini-1.5-flash' in model map, if 'gemini/gemini-1.5-flash-001' given.
+    3. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" and custom_llm_provider=None
+    4. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
     """
     if combined_model_name in litellm.model_cost:
         key = combined_model_name
@@ -5217,6 +5228,26 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                 pass
             else:
                 raise Exception
+    elif combined_stripped_model_name in litellm.model_cost:
+        key = model
+        _model_info = litellm.model_cost[combined_stripped_model_name]
+        _model_info["supported_openai_params"] = supported_openai_params
+        if (
+            "litellm_provider" in _model_info
+            and _model_info["litellm_provider"] != custom_llm_provider
+        ):
+            if custom_llm_provider == "vertex_ai" and _model_info[
+                "litellm_provider"
+            ].startswith("vertex_ai"):
+                pass
+            else:
+                raise Exception(
+                    "Got provider={}, Expected provider={}, for model={}".format(
+                        _model_info["litellm_provider"],
+                        custom_llm_provider,
+                        model,
+                    )
+                )
     elif model in litellm.model_cost:
         key = model
         _model_info = litellm.model_cost[model]
@@ -5320,9 +5351,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                 "supports_assistant_prefill", False
             ),
         )
-    except Exception:
+    except Exception as e:
         raise Exception(
-            "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+            "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.".format(
                 model, custom_llm_provider
             )
         )
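The helper added in the first hunk above only strips a trailing "-<digits>" stable-version suffix, which is what lets a pinned Gemini model fall back to its base entry in the cost map. A quick illustration of the new lookup key (the model names are examples, not new entries in the map):

    import re


    def _strip_stable_vertex_version(model_name) -> str:
        return re.sub(r"-\d+$", "", model_name)


    print(_strip_stable_vertex_version("gemini-1.5-flash-001"))  # -> "gemini-1.5-flash"
    print(_strip_stable_vertex_version("gemini-1.5-pro"))        # -> "gemini-1.5-pro" (no suffix, unchanged)

    # With custom_llm_provider="gemini" and model="gemini-1.5-flash-001",
    # get_model_info() now checks, in order:
    #   1. "gemini/gemini-1.5-flash-001"  (combined_model_name)
    #   2. "gemini/gemini-1.5-flash"      (combined_stripped_model_name)
    #   3. "gemini-1.5-flash-001"         (model)
    #   4. split_model
    # matching the updated docstring above.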
|
Loading…
Add table
Add a link
Reference in a new issue