Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Fix #8296 - Modify gemini cache configuration to also move tools to the cache

This commit is contained in:
parent 3c463f6715
commit 7d5e7b6c13

5 changed files with 398 additions and 64 deletions
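Background: Google rejects a generateContent call that sets tools or tool_config while also referencing cachedContent, so any tools supplied alongside a cache-marked prompt have to be written into the CachedContent body instead. A minimal sketch of the kind of request this affects (the model name and tool schema below are illustrative, not taken from the commit):

    import litellm

    response = litellm.completion(
        model="gemini/gemini-2.0-flash-001",
        messages=[
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "you are a helpful assistant",
                        # marks this block as cacheable context
                        "cache_control": {"type": "ephemeral"},
                    }
                ],
            },
            {"role": "user", "content": "What is the weather now?"},
        ],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather",
                },
            }
        ],
    )

Before this change the cached system prompt moved into CachedContent while the tools stayed on the generateContent request, which Google rejects with the INVALID_ARGUMENT error quoted in the new docstring below; with this change, tools and tool_choice travel with the cached content.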
@@ -1,13 +1,17 @@
 """
 Transformation logic for context caching.

 Why separate file? Make it easy to see how transformation works
 """

-from typing import List, Tuple
+from dataclasses import dataclass, replace
+from typing import Any, Dict, List, Optional, Tuple

+from litellm.caching.caching import Cache
+from litellm.types.caching import LiteLLMCacheType
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.llms.vertex_ai import CachedContentRequestBody
+from litellm.types.llms.vertex_ai import CachedContentRequestBody, ToolConfig, Tools
 from litellm.utils import is_cached_message

 from ..common_utils import get_supports_system_message

@@ -88,7 +92,11 @@ def separate_cached_messages(


 def transform_openai_messages_to_gemini_context_caching(
-    model: str, messages: List[AllMessageValues], cache_key: str
+    model: str,
+    messages: List[AllMessageValues],
+    cache_key: str,
+    tools: Optional[List[Tools]] = None,
+    tool_choice: Optional[ToolConfig] = None,
 ) -> CachedContentRequestBody:
     supports_system_message = get_supports_system_message(
         model=model, custom_llm_provider="gemini"

@@ -103,8 +111,126 @@ def transform_openai_messages_to_gemini_context_caching(
         contents=transformed_messages,
         model="models/{}".format(model),
         displayName=cache_key,
+        tools=tools,
+        toolConfig=tool_choice,
     )
     if transformed_system_messages is not None:
         data["system_instruction"] = transformed_system_messages

     return data
+
+
+local_cache_obj = Cache(type=LiteLLMCacheType.LOCAL)
+
+
+@dataclass(frozen=True)
+class CacheSplitResult:
+    """Result of splitting messages into cacheable and non-cacheable parts"""
+
+    remaining_messages: List[
+        AllMessageValues
+    ]  # Messages that should be sent in actual request
+    optional_params: Dict[str, Any]  # Updated params to be sent in actual request
+    cache_key: Optional[str]  # Key to use for checking if content is already cached
+    cached_content: Optional[
+        str
+    ]  # cached content ID, no further processing is needed once this is defined
+    cache_request_body: Optional[
+        CachedContentRequestBody
+    ]  # Request body to create new cache if needed
+
+    def with_cached_content(self, cached_content: str) -> "CacheSplitResult":
+        """
+        Returns an updated CacheSplitResult with the cached content applied.
+        """
+        updated_params = {**self.optional_params, "cached_content": cached_content}
+        return replace(
+            self,
+            cached_content=cached_content,
+            optional_params=updated_params,
+            cache_request_body=None,
+        )
+
+
+def extract_cache_configuration(
+    model: str,
+    messages: List[AllMessageValues],
+    optional_params: Dict[str, Any],
+) -> CacheSplitResult:
+    """
+    Checks if a given request should have a cache, and if so, extracts the cache configuration, returning
+    a modified version of the messages and optional params.
+
+    - Removes the cached content from the messages
+    - Adds the cache key to the optional params
+    - If there's cached content, also moves the tool call and tool choice to the optional params, as that is
+      required for the cache to work. (The tools are moved into some sort of system prompt on google's side)
+
+    Relevant error:
+        "error": {
+            "code": 400,
+            "message": "CachedContent can not be used with GenerateContent request setting system_instruction, tools or tool_config.\n\nProposed fix: move those values to CachedContent from GenerateContent request.",
+            "status": "INVALID_ARGUMENT"
+        }
+
+    Returns:
+        CacheSplitResult with:
+        - remaining_messages: Messages that should be sent in the actual request
+        - cache_key: The key to use for checking if content is already cached
+        - cached_content: The cached content ID if already provided
+        - cache_request_body: The request body to create a new cache entry if needed
+    """
+    # If cached content is already provided, no need to process messages
+    if (
+        "cached_content" in optional_params
+        and optional_params["cached_content"] is not None
+    ):
+        return CacheSplitResult(
+            remaining_messages=messages,
+            optional_params=optional_params,
+            cache_key=None,
+            cached_content=optional_params["cached_content"],
+            cache_request_body=None,
+        )
+
+    # Separate messages that can be cached from those that can't
+    cached_messages, non_cached_messages = separate_cached_messages(messages=messages)
+
+    # If no messages can be cached, return original messages
+    if len(cached_messages) == 0:
+        return CacheSplitResult(
+            remaining_messages=messages,
+            optional_params=optional_params,
+            cache_key=None,
+            cached_content=None,
+            cache_request_body=None,
+        )
+
+    if "tools" in optional_params or "tool_choice" in optional_params:
+        optional_params = optional_params.copy()
+        tools = optional_params.pop("tools", None)
+        tool_choice = optional_params.pop("tool_choice", None)
+    else:
+        tools = None
+        tool_choice = None
+    key_kwargs = {}
+    if tools is not None:
+        key_kwargs["tools"] = tools
+    if tool_choice is not None:
+        key_kwargs["tool_choice"] = tool_choice
+
+    # Generate cache key for the cacheable messages
+    cache_key = local_cache_obj.get_cache_key(messages=cached_messages, **key_kwargs)
+
+    # Transform cached messages into request body
+    cache_request_body = transform_openai_messages_to_gemini_context_caching(
+        model=model, messages=cached_messages, cache_key=cache_key, tools=tools, tool_choice=tool_choice
+    )
+
+    return CacheSplitResult(
+        remaining_messages=non_cached_messages,
+        optional_params=optional_params,
+        cache_key=cache_key,
+        cached_content=None,
+        cache_request_body=cache_request_body,
+    )
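For reference, a rough sketch of how the new helper is intended to be consumed by the caching handler (the real call sites are in the handler changes below; send_request and create_cached_content are hypothetical stand-ins, not functions from this commit):

    split = extract_cache_configuration(
        model="gemini-2.0-flash-001",
        messages=messages,
        optional_params=optional_params,
    )
    if split.cached_content is not None:
        # a cache ID was already supplied; optional_params still carries "cached_content"
        response = send_request(split.remaining_messages, split.optional_params)
    elif split.cache_request_body is None:
        # nothing cacheable in the messages; the request goes through unchanged
        response = send_request(split.remaining_messages, split.optional_params)
    else:
        # create the cache first (tools/toolConfig ride along in cache_request_body),
        # then attach the returned cache name before sending the remaining messages
        cache_name = create_cached_content(split.cache_request_body)  # hypothetical helper
        split = split.with_cached_content(cached_content=cache_name)
        response = send_request(split.remaining_messages, split.optional_params)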
@@ -1,9 +1,8 @@
-from typing import List, Literal, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union, Dict, Any

 import httpx

 import litellm
-from litellm.caching.caching import Cache, LiteLLMCacheType
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,

@@ -19,15 +18,10 @@ from litellm.types.llms.vertex_ai import (
 from ..common_utils import VertexAIError
 from ..vertex_llm_base import VertexBase
 from .transformation import (
-    separate_cached_messages,
-    transform_openai_messages_to_gemini_context_caching,
+    CacheSplitResult,
+    extract_cache_configuration,
 )

-local_cache_obj = Cache(
-    type=LiteLLMCacheType.LOCAL
-)  # only used for calling 'get_cache_key' function
-
-
 class ContextCachingEndpoints(VertexBase):
     """
     Covers context caching endpoints for Vertex AI + Google AI Studio

@@ -205,6 +199,7 @@ class ContextCachingEndpoints(VertexBase):
     def check_and_create_cache(
         self,
         messages: List[AllMessageValues],  # receives openai format messages
+        optional_params: Dict[str, Any],
         api_key: str,
         api_base: Optional[str],
         model: str,

@@ -212,8 +207,7 @@ class ContextCachingEndpoints(VertexBase):
         timeout: Optional[Union[float, httpx.Timeout]],
         logging_obj: Logging,
         extra_headers: Optional[dict] = None,
-        cached_content: Optional[str] = None,
-    ) -> Tuple[List[AllMessageValues], Optional[str]]:
+    ) -> CacheSplitResult:
         """
         Receives
         - messages: List of dict - messages in the openai format

@@ -224,8 +218,19 @@ class ContextCachingEndpoints(VertexBase):

         Follows - https://ai.google.dev/api/caching#request-body
         """
-        if cached_content is not None:
-            return messages, cached_content
+        cache_split_result = extract_cache_configuration(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+        )
+
+        if (
+            cache_split_result.cache_request_body is None
+            or cache_split_result.cached_content is not None
+            or cache_split_result.cache_key is None
+        ):
+            return cache_split_result

         ## AUTHORIZATION ##
         token, url = self._get_token_and_url_context_caching(

@@ -252,17 +257,9 @@ class ContextCachingEndpoints(VertexBase):
         else:
             client = client

-        cached_messages, non_cached_messages = separate_cached_messages(
-            messages=messages
-        )
-
-        if len(cached_messages) == 0:
-            return messages, None
-
         ## CHECK IF CACHED ALREADY
-        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
         google_cache_name = self.check_cache(
-            cache_key=generated_cache_key,
+            cache_key=cache_split_result.cache_key,
             client=client,
             headers=headers,
             api_key=api_key,

@@ -270,21 +267,16 @@ class ContextCachingEndpoints(VertexBase):
             logging_obj=logging_obj,
         )
         if google_cache_name:
-            return non_cached_messages, google_cache_name
+            return cache_split_result.with_cached_content(cached_content=google_cache_name)

         ## TRANSFORM REQUEST
-        cached_content_request_body = (
-            transform_openai_messages_to_gemini_context_caching(
-                model=model, messages=cached_messages, cache_key=generated_cache_key
-            )
-        )

         ## LOGGING
         logging_obj.pre_call(
             input=messages,
             api_key="",
             additional_args={
-                "complete_input_dict": cached_content_request_body,
+                "complete_input_dict": cache_split_result.cache_request_body,
                 "api_base": url,
                 "headers": headers,
             },

@@ -292,7 +284,7 @@ class ContextCachingEndpoints(VertexBase):

         try:
             response = client.post(
-                url=url, headers=headers, json=cached_content_request_body  # type: ignore
+                url=url, headers=headers, json=cache_split_result.cache_request_body  # type: ignore
             )
             response.raise_for_status()
         except httpx.HTTPStatusError as err:

@@ -305,11 +297,12 @@ class ContextCachingEndpoints(VertexBase):
         cached_content_response_obj = VertexAICachedContentResponseObject(
             name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
         )
-        return (non_cached_messages, cached_content_response_obj["name"])
+        return cache_split_result.with_cached_content(cached_content=cached_content_response_obj["name"])

     async def async_check_and_create_cache(
         self,
         messages: List[AllMessageValues],  # receives openai format messages
+        optional_params: Dict[str, Any],
         api_key: str,
         api_base: Optional[str],
         model: str,

@@ -317,8 +310,7 @@ class ContextCachingEndpoints(VertexBase):
         timeout: Optional[Union[float, httpx.Timeout]],
         logging_obj: Logging,
         extra_headers: Optional[dict] = None,
-        cached_content: Optional[str] = None,
-    ) -> Tuple[List[AllMessageValues], Optional[str]]:
+    ) -> CacheSplitResult:
         """
         Receives
         - messages: List of dict - messages in the openai format

@@ -329,15 +321,19 @@ class ContextCachingEndpoints(VertexBase):

         Follows - https://ai.google.dev/api/caching#request-body
         """
-        if cached_content is not None:
-            return messages, cached_content
-
-        cached_messages, non_cached_messages = separate_cached_messages(
-            messages=messages
+        cache_split_result = extract_cache_configuration(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
         )

-        if len(cached_messages) == 0:
-            return messages, None
+        if (
+            cache_split_result.cache_request_body is None
+            or cache_split_result.cached_content is not None
+            or cache_split_result.cache_key is None
+        ):
+            return cache_split_result

         ## AUTHORIZATION ##
         token, url = self._get_token_and_url_context_caching(

@@ -362,9 +358,8 @@ class ContextCachingEndpoints(VertexBase):
             client = client

         ## CHECK IF CACHED ALREADY
-        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
         google_cache_name = await self.async_check_cache(
-            cache_key=generated_cache_key,
+            cache_key=cache_split_result.cache_key,
             client=client,
             headers=headers,
             api_key=api_key,

@@ -372,21 +367,16 @@ class ContextCachingEndpoints(VertexBase):
             logging_obj=logging_obj,
         )
         if google_cache_name:
-            return non_cached_messages, google_cache_name
+            return cache_split_result.with_cached_content(cached_content=google_cache_name)

         ## TRANSFORM REQUEST
-        cached_content_request_body = (
-            transform_openai_messages_to_gemini_context_caching(
-                model=model, messages=cached_messages, cache_key=generated_cache_key
-            )
-        )

         ## LOGGING
         logging_obj.pre_call(
             input=messages,
             api_key="",
             additional_args={
-                "complete_input_dict": cached_content_request_body,
+                "complete_input_dict": cache_split_result.cache_request_body,
                 "api_base": url,
                 "headers": headers,
             },

@@ -394,7 +384,7 @@ class ContextCachingEndpoints(VertexBase):

         try:
             response = await client.post(
-                url=url, headers=headers, json=cached_content_request_body  # type: ignore
+                url=url, headers=headers, json=cache_split_result.cache_request_body  # type: ignore
             )
             response.raise_for_status()
         except httpx.HTTPStatusError as err:

@@ -407,7 +397,7 @@ class ContextCachingEndpoints(VertexBase):
         cached_content_response_obj = VertexAICachedContentResponseObject(
             name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
         )
-        return (non_cached_messages, cached_content_response_obj["name"])
+        return cache_split_result.with_cached_content(cached_content=cached_content_response_obj["name"])

     def get_cache(self):
         pass
@@ -1,5 +1,5 @@
 """
 Transformation logic from OpenAI format to Gemini format.

 Why separate file? Make it easy to see how transformation works
 """

@@ -431,17 +431,20 @@ def sync_transform_request_body(
     context_caching_endpoints = ContextCachingEndpoints()

     if gemini_api_key is not None:
-        messages, cached_content = context_caching_endpoints.check_and_create_cache(
+        cache_split_result = context_caching_endpoints.check_and_create_cache(
             messages=messages,
+            optional_params=optional_params,
             api_key=gemini_api_key,
             api_base=api_base,
             model=model,
             client=client,
             timeout=timeout,
             extra_headers=extra_headers,
-            cached_content=optional_params.pop("cached_content", None),
             logging_obj=logging_obj,
         )
+        messages = cache_split_result.remaining_messages
+        cached_content = cache_split_result.cached_content
+        optional_params = cache_split_result.optional_params
     else:  # [TODO] implement context caching for gemini as well
         cached_content = optional_params.pop("cached_content", None)

@@ -473,20 +476,20 @@ async def async_transform_request_body(
     context_caching_endpoints = ContextCachingEndpoints()

     if gemini_api_key is not None:
-        (
-            messages,
-            cached_content,
-        ) = await context_caching_endpoints.async_check_and_create_cache(
+        cache_split_result = await context_caching_endpoints.async_check_and_create_cache(
             messages=messages,
+            optional_params=optional_params,
             api_key=gemini_api_key,
             api_base=api_base,
             model=model,
             client=client,
             timeout=timeout,
             extra_headers=extra_headers,
-            cached_content=optional_params.pop("cached_content", None),
             logging_obj=logging_obj,
         )
+        messages = cache_split_result.remaining_messages
+        cached_content = cache_split_result.cached_content
+        optional_params = cache_split_result.optional_params
     else:  # [TODO] implement context caching for gemini as well
         cached_content = optional_params.pop("cached_content", None)
@@ -251,8 +251,8 @@ class RequestBody(TypedDict, total=False):
 class CachedContentRequestBody(TypedDict, total=False):
     contents: Required[List[ContentType]]
     system_instruction: SystemInstructions
-    tools: Tools
-    toolConfig: ToolConfig
+    tools: Optional[List[Tools]]
+    toolConfig: Optional[ToolConfig]
     model: Required[str]  # Format: models/{model}
     ttl: str  # ending in 's' - Example: "3.5s".
     displayName: str
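The widened TypedDict now mirrors what the context-caching transformation emits. A small sketch of a populated body (field values are illustrative; the tool and config shapes follow the expected request bodies in the tests below):

    body = CachedContentRequestBody(
        contents=[],
        model="models/gemini-2.0-flash-001",
        displayName="<cache key>",
        tools=[{"functionDeclarations": [{"name": "get_weather", "description": "Get the current weather"}]}],
        toolConfig={"functionCallingConfig": {"mode": "ANY"}},
        system_instruction={"parts": [{"text": "you are a helpful assistant"}]},
    )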
@@ -1,14 +1,21 @@
 import asyncio
 from typing import List, cast
+import json
 from unittest.mock import MagicMock

 import pytest
+import respx
 from pydantic import BaseModel

 import litellm
 from litellm import ModelResponse
+from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.llms.gemini.chat.transformation import GoogleAIStudioGeminiConfig
+from litellm.llms.vertex_ai.gemini.transformation import async_transform_request_body, sync_transform_request_body
 from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
+    VertexLLM,
 )
 from litellm.types.utils import ChoiceLogprobs

@@ -239,3 +246,211 @@ def test_vertex_ai_thinking_output_part():
     content, reasoning_content = v.get_assistant_content_message(parts=parts)
     assert content == "Hello world"
     assert reasoning_content == "I'm thinking..."
+
+
+def _mock_logging():
+    mock_logging = MagicMock(spec=LiteLLMLoggingObj)
+    mock_logging.pre_call = MagicMock(return_value=None)
+    mock_logging.post_call = MagicMock(return_value=None)
+    return mock_logging
+
+
+def _mock_get_post_cached_content(api_key: str, respx_mock: respx.MockRouter) -> tuple[respx.MockRouter, respx.MockRouter]:
+    get_mock = respx_mock.get(
+        f"https://generativelanguage.googleapis.com/v1beta/cachedContents?key={api_key}"
+    ).respond(
+        json={
+            "cachedContents": [],
+            "nextPageToken": None,
+        }
+    )
+
+    post_mock = respx_mock.post(
+        f"https://generativelanguage.googleapis.com/v1beta/cachedContents?key={api_key}"
+    ).respond(
+        json={
+            "name": "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id",
+            "model": "gemini-2.0-flash-001",
+        }
+    )
+    return get_mock, post_mock
+
+
+def test_google_ai_studio_gemini_message_caching_sync(
+    # ideally this would unit test just a small transformation, but there's a lot going on with gemini/vertex
+    # (hinges around branching for sync/async transformations).
+    respx_mock: respx.MockRouter,
+):
+    mock_logging = _mock_logging()
+
+    get_mock, post_mock = _mock_get_post_cached_content("fake_api_key", respx_mock)
+
+    transformed_request = sync_transform_request_body(
+        gemini_api_key="fake_api_key",
+        messages=[
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "you are a helpful assistant",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "content": "Hello, world!",
+            },
+        ],
+        api_base=None,
+        model="gemini-2.0-flash-001",
+        client=None,
+        timeout=None,
+        extra_headers=None,
+        optional_params={},
+        logging_obj=mock_logging,
+        custom_llm_provider="vertex_ai",
+        litellm_params={},
+    )
+    # Assert both GET and POST endpoints were called
+    assert get_mock.calls.call_count == 1
+    assert post_mock.calls.call_count == 1
+    assert json.loads(post_mock.calls[0].request.content) == {
+        "contents": [],
+        "model": "models/gemini-2.0-flash-001",
+        "displayName": "203ae753b6c793e1af13b13d0710de5863c486e610963ce243b07ee6830ce1d2",
+        "tools": None,
+        "toolConfig": None,
+        "system_instruction": {"parts": [{"text": "you are a helpful assistant"}]},
+    }
+
+    assert transformed_request["contents"] == [
+        {"parts": [{"text": "Hello, world!"}], "role": "user"}
+    ]
+    assert (
+        transformed_request["cachedContent"]
+        == "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id"
+    )
+
+
+_GET_WEATHER_MESSAGES = [
+    {
+        "role": "system",
+        "content": [
+            {
+                "type": "text",
+                "text": "you are a helpful assistant",
+                "cache_control": {"type": "ephemeral"},
+            }
+        ],
+    },
+    {
+        "role": "user",
+        "content": "What is the weather now?",
+    },
+]
+
+_GET_WEATHER_TOOLS_OPTIONAL_PARAMS = {
+    "tools": [
+        {
+            "functionDeclarations": [
+                {"name": "get_weather", "description": "Get the current weather"}
+            ],
+        }
+    ],
+    "tool_choice": {
+        "functionCallingConfig": {
+            "mode": "ANY"
+        }
+    },
+}
+
+_EXPECTED_GET_WEATHER_CACHED_CONTENT_REQUEST_BODY = {
+    "contents": [],
+    "model": "models/gemini-2.0-flash-001",
+    "displayName": "62398619ff33908a18561c1a342c580c3d876f169d103ec52128df38f04e03d1",
+    "tools": [
+        {
+            "functionDeclarations": [
+                {"name": "get_weather", "description": "Get the current weather"}
+            ],
+        }
+    ],
+    "toolConfig": {
+        "functionCallingConfig": {
+            "mode": "ANY"
+        }
+    },
+    "system_instruction": {"parts": [{"text": "you are a helpful assistant"}]},
+}
+
+
+def test_google_ai_studio_gemini_message_caching_with_tools_sync(
+    respx_mock: respx.MockRouter,
+):
+    mock_logging = _mock_logging()
+
+    get_mock, post_mock = _mock_get_post_cached_content("fake_api_key", respx_mock)
+
+    transformed_request = sync_transform_request_body(
+        gemini_api_key="fake_api_key",
+        messages=_GET_WEATHER_MESSAGES,
+        api_base=None,
+        model="gemini-2.0-flash-001",
+        client=None,
+        timeout=None,
+        extra_headers=None,
+        optional_params=_GET_WEATHER_TOOLS_OPTIONAL_PARAMS,
+        logging_obj=mock_logging,
+        custom_llm_provider="vertex_ai",
+        litellm_params={},
+    )
+    # Assert both GET and POST endpoints were called
+    assert get_mock.calls.call_count == 1
+    assert post_mock.calls.call_count == 1
+    assert json.loads(post_mock.calls[0].request.content) == _EXPECTED_GET_WEATHER_CACHED_CONTENT_REQUEST_BODY
+
+    assert transformed_request["contents"] == [
+        {"parts": [{"text": "What is the weather now?"}], "role": "user"}
+    ]
+    assert (
+        transformed_request["cachedContent"]
+        == "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id"
+    )
+    assert transformed_request.get("tools") is None
+    assert transformed_request.get("tool_choice") is None
+
+
+@pytest.mark.asyncio
+async def test_google_ai_studio_gemini_message_caching_with_tools_async(
+    respx_mock: respx.MockRouter,
+):
+    mock_logging = _mock_logging()
+
+    get_mock, post_mock = _mock_get_post_cached_content("fake_api_key", respx_mock)
+
+    transformed_request = await async_transform_request_body(
+        gemini_api_key="fake_api_key",
+        messages=_GET_WEATHER_MESSAGES,
+        api_base=None,
+        model="gemini-2.0-flash-001",
+        client=None,
+        timeout=None,
+        extra_headers=None,
+        optional_params=_GET_WEATHER_TOOLS_OPTIONAL_PARAMS,
+        logging_obj=mock_logging,
+        custom_llm_provider="vertex_ai",
+        litellm_params={},
+    )
+    # Assert both GET and POST endpoints were called
+    assert get_mock.calls.call_count == 1
+    assert post_mock.calls.call_count == 1
+    assert json.loads(post_mock.calls[0].request.content) == _EXPECTED_GET_WEATHER_CACHED_CONTENT_REQUEST_BODY
+
+    assert transformed_request["contents"] == [
+        {"parts": [{"text": "What is the weather now?"}], "role": "user"}
+    ]
+    assert (
+        transformed_request["cachedContent"]
+        == "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id"
+    )
+    assert transformed_request.get("tools") is None
+    assert transformed_request.get("tool_choice") is None