Fix #8296 - Modify gemini cache configuration to also move tools to the cache

Adrian Lyjak 2025-04-19 00:32:10 -04:00
parent 3c463f6715
commit 7d5e7b6c13
5 changed files with 398 additions and 64 deletions
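Gemini's cachedContents API rejects generateContent requests that set tools or tool_config alongside cachedContent (the 400 INVALID_ARGUMENT quoted in the transformation below), so this change moves tools and tool_choice into the cached-content creation request instead of the per-call request body. A hedged sketch of the user-facing call this is meant to support; the model name, tool schema, and tool_choice value are illustrative, not taken from this diff:

import litellm

response = litellm.completion(
    model="gemini/gemini-2.0-flash-001",  # Google AI Studio route; illustrative model name
    messages=[
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "you are a helpful assistant",
                    "cache_control": {"type": "ephemeral"},  # marks this block as cacheable
                }
            ],
        },
        {"role": "user", "content": "What is the weather now?"},
    ],
    # OpenAI-format tool definition; litellm converts it to Gemini functionDeclarations
    tools=[
        {
            "type": "function",
            "function": {"name": "get_weather", "description": "Get the current weather"},
        }
    ],
    tool_choice="auto",
)

With this commit the cache_control-marked block and the tools both land in the CachedContent entry, and the follow-up generateContent request carries only the remaining user turn plus the cachedContent name.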

View file

@@ -1,13 +1,17 @@
 """
 Transformation logic for context caching.

 Why separate file? Make it easy to see how transformation works
 """

-from typing import List, Tuple
+from dataclasses import dataclass, replace
+from typing import Any, Dict, List, Optional, Tuple

+from litellm.caching.caching import Cache
+from litellm.types.caching import LiteLLMCacheType
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.llms.vertex_ai import CachedContentRequestBody
+from litellm.types.llms.vertex_ai import CachedContentRequestBody, ToolConfig, Tools
 from litellm.utils import is_cached_message

 from ..common_utils import get_supports_system_message
@@ -88,7 +92,11 @@ def separate_cached_messages(


 def transform_openai_messages_to_gemini_context_caching(
-    model: str, messages: List[AllMessageValues], cache_key: str
+    model: str,
+    messages: List[AllMessageValues],
+    cache_key: str,
+    tools: Optional[List[Tools]] = None,
+    tool_choice: Optional[ToolConfig] = None,
 ) -> CachedContentRequestBody:
     supports_system_message = get_supports_system_message(
         model=model, custom_llm_provider="gemini"
@@ -103,8 +111,126 @@ def transform_openai_messages_to_gemini_context_caching
         contents=transformed_messages,
         model="models/{}".format(model),
         displayName=cache_key,
+        tools=tools,
+        toolConfig=tool_choice,
     )
     if transformed_system_messages is not None:
         data["system_instruction"] = transformed_system_messages

     return data
+
+
+local_cache_obj = Cache(type=LiteLLMCacheType.LOCAL)
+
+
+@dataclass(frozen=True)
+class CacheSplitResult:
+    """Result of splitting messages into cacheable and non-cacheable parts"""
+
+    remaining_messages: List[
+        AllMessageValues
+    ]  # Messages that should be sent in actual request
+    optional_params: Dict[str, Any]  # Updated params to be sent in actual request
+    cache_key: Optional[str]  # Key to use for checking if content is already cached
+    cached_content: Optional[
+        str
+    ]  # cached content ID, no further processing is needed once this is defined
+    cache_request_body: Optional[
+        CachedContentRequestBody
+    ]  # Request body to create new cache if needed
+
+    def with_cached_content(self, cached_content: str) -> "CacheSplitResult":
+        """
+        Returns an updated CacheSplitResult with the cached content applied.
+        """
+        updated_params = {**self.optional_params, "cached_content": cached_content}
+        return replace(
+            self,
+            cached_content=cached_content,
+            optional_params=updated_params,
+            cache_request_body=None,
+        )
+
+
+def extract_cache_configuration(
+    model: str,
+    messages: List[AllMessageValues],
+    optional_params: Dict[str, Any],
+) -> CacheSplitResult:
+    """
+    Checks if a given request should have a cache, and if so, extracts the cache configuration, returning
+    a modified version of the messages and optional params.
+
+    - Removes the cached content from the messages
+    - Adds the cache key to the optional params
+    - If there is content to cache, also moves the tools and tool_choice out of the optional params and
+      into the cache request body, since Gemini requires that for the cache to work. (On Google's side
+      the tools are folded into something like a system prompt.)
+
+    Relevant error:
+        "error": {
+            "code": 400,
+            "message": "CachedContent can not be used with GenerateContent request setting system_instruction, tools or tool_config.\n\nProposed fix: move those values to CachedContent from GenerateContent request.",
+            "status": "INVALID_ARGUMENT"
+        }
+
+    Returns:
+        CacheSplitResult with:
+        - remaining_messages: Messages that should be sent in the actual request
+        - cache_key: The key to use for checking if content is already cached
+        - cached_content: The cached content ID if already provided
+        - cache_request_body: The request body to create a new cache entry if needed
+    """
+    # If cached content is already provided, no need to process messages
+    if (
+        "cached_content" in optional_params
+        and optional_params["cached_content"] is not None
+    ):
+        return CacheSplitResult(
+            remaining_messages=messages,
+            optional_params=optional_params,
+            cache_key=None,
+            cached_content=optional_params["cached_content"],
+            cache_request_body=None,
+        )
+
+    # Separate messages that can be cached from those that can't
+    cached_messages, non_cached_messages = separate_cached_messages(messages=messages)
+
+    # If no messages can be cached, return original messages
+    if len(cached_messages) == 0:
+        return CacheSplitResult(
+            remaining_messages=messages,
+            optional_params=optional_params,
+            cache_key=None,
+            cached_content=None,
+            cache_request_body=None,
+        )
+
+    if "tools" in optional_params or "tool_choice" in optional_params:
+        optional_params = optional_params.copy()
+        tools = optional_params.pop("tools", None)
+        tool_choice = optional_params.pop("tool_choice", None)
+    else:
+        tools = None
+        tool_choice = None
+
+    key_kwargs = {}
+    if tools is not None:
+        key_kwargs["tools"] = tools
+    if tool_choice is not None:
+        key_kwargs["tool_choice"] = tool_choice
+
+    # Generate cache key for the cacheable messages
+    cache_key = local_cache_obj.get_cache_key(messages=cached_messages, **key_kwargs)
+
+    # Transform cached messages into request body
+    cache_request_body = transform_openai_messages_to_gemini_context_caching(
+        model=model,
+        messages=cached_messages,
+        cache_key=cache_key,
+        tools=tools,
+        tool_choice=tool_choice,
+    )
+
+    return CacheSplitResult(
+        remaining_messages=non_cached_messages,
+        optional_params=optional_params,
+        cache_key=cache_key,
+        cached_content=None,
+        cache_request_body=cache_request_body,
+    )
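The split itself is pure Python and can be exercised without any network calls. A minimal sketch, assuming the module path litellm.llms.vertex_ai.context_caching.transformation (inferred from the relative imports above) and Gemini-format tools as used in the tests at the bottom of this commit; the asserted outcomes mirror those tests rather than anything guaranteed here:

from litellm.llms.vertex_ai.context_caching.transformation import (
    extract_cache_configuration,
)

messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "you are a helpful assistant",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {"role": "user", "content": "What is the weather now?"},
]
optional_params = {
    "tools": [{"functionDeclarations": [{"name": "get_weather"}]}],
    "tool_choice": {"functionCallingConfig": {"mode": "ANY"}},
}

split = extract_cache_configuration(
    model="gemini-2.0-flash-001",
    messages=messages,
    optional_params=optional_params,
)

# The cache_control-marked system message moves into the cache request body
# (as system_instruction), together with tools and toolConfig.
assert split.cache_request_body is not None
assert split.cache_request_body.get("tools") == optional_params["tools"]
# Only the non-cached user turn is left for the actual generateContent call.
assert split.remaining_messages == [{"role": "user", "content": "What is the weather now?"}]
# tools/tool_choice were popped from the params that will accompany that call.
assert "tools" not in split.optional_params and "tool_choice" not in split.optional_params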

View file

@@ -1,9 +1,8 @@
-from typing import List, Literal, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union, Dict, Any

 import httpx

 import litellm
-from litellm.caching.caching import Cache, LiteLLMCacheType
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
@@ -19,15 +18,10 @@ from litellm.types.llms.vertex_ai import (

 from ..common_utils import VertexAIError
 from ..vertex_llm_base import VertexBase
 from .transformation import (
-    separate_cached_messages,
-    transform_openai_messages_to_gemini_context_caching,
+    CacheSplitResult,
+    extract_cache_configuration,
 )

-local_cache_obj = Cache(
-    type=LiteLLMCacheType.LOCAL
-)  # only used for calling 'get_cache_key' function


 class ContextCachingEndpoints(VertexBase):
     """
     Covers context caching endpoints for Vertex AI + Google AI Studio
@@ -205,6 +199,7 @@ class ContextCachingEndpoints(VertexBase):
     def check_and_create_cache(
         self,
         messages: List[AllMessageValues],  # receives openai format messages
+        optional_params: Dict[str, Any],
         api_key: str,
         api_base: Optional[str],
         model: str,
@@ -212,8 +207,7 @@ class ContextCachingEndpoints(VertexBase):
         timeout: Optional[Union[float, httpx.Timeout]],
         logging_obj: Logging,
         extra_headers: Optional[dict] = None,
-        cached_content: Optional[str] = None,
-    ) -> Tuple[List[AllMessageValues], Optional[str]]:
+    ) -> CacheSplitResult:
         """
         Receives
         - messages: List of dict - messages in the openai format
@@ -224,8 +218,19 @@ class ContextCachingEndpoints(VertexBase):

         Follows - https://ai.google.dev/api/caching#request-body
         """
-        if cached_content is not None:
-            return messages, cached_content
+        cache_split_result = extract_cache_configuration(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+        )
+
+        if (
+            cache_split_result.cache_request_body is None
+            or cache_split_result.cached_content is not None
+            or cache_split_result.cache_key is None
+        ):
+            return cache_split_result

         ## AUTHORIZATION ##
         token, url = self._get_token_and_url_context_caching(
@@ -252,17 +257,9 @@ class ContextCachingEndpoints(VertexBase):
         else:
             client = client

-        cached_messages, non_cached_messages = separate_cached_messages(
-            messages=messages
-        )
-
-        if len(cached_messages) == 0:
-            return messages, None
-
         ## CHECK IF CACHED ALREADY
-        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
         google_cache_name = self.check_cache(
-            cache_key=generated_cache_key,
+            cache_key=cache_split_result.cache_key,
             client=client,
             headers=headers,
             api_key=api_key,
@@ -270,21 +267,16 @@ class ContextCachingEndpoints(VertexBase):
             logging_obj=logging_obj,
         )
         if google_cache_name:
-            return non_cached_messages, google_cache_name
+            return cache_split_result.with_cached_content(cached_content=google_cache_name)

         ## TRANSFORM REQUEST
-        cached_content_request_body = (
-            transform_openai_messages_to_gemini_context_caching(
-                model=model, messages=cached_messages, cache_key=generated_cache_key
-            )
-        )

         ## LOGGING
         logging_obj.pre_call(
             input=messages,
             api_key="",
             additional_args={
-                "complete_input_dict": cached_content_request_body,
+                "complete_input_dict": cache_split_result.cache_request_body,
                 "api_base": url,
                 "headers": headers,
             },
@@ -292,7 +284,7 @@ class ContextCachingEndpoints(VertexBase):
         try:
             response = client.post(
-                url=url, headers=headers, json=cached_content_request_body  # type: ignore
+                url=url, headers=headers, json=cache_split_result.cache_request_body  # type: ignore
             )
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
@@ -305,11 +297,12 @@ class ContextCachingEndpoints(VertexBase):
         cached_content_response_obj = VertexAICachedContentResponseObject(
             name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
         )
-        return (non_cached_messages, cached_content_response_obj["name"])
+        return cache_split_result.with_cached_content(cached_content=cached_content_response_obj["name"])
     async def async_check_and_create_cache(
         self,
         messages: List[AllMessageValues],  # receives openai format messages
+        optional_params: Dict[str, Any],
         api_key: str,
         api_base: Optional[str],
         model: str,
@@ -317,8 +310,7 @@ class ContextCachingEndpoints(VertexBase):
         timeout: Optional[Union[float, httpx.Timeout]],
         logging_obj: Logging,
         extra_headers: Optional[dict] = None,
-        cached_content: Optional[str] = None,
-    ) -> Tuple[List[AllMessageValues], Optional[str]]:
+    ) -> CacheSplitResult:
         """
         Receives
         - messages: List of dict - messages in the openai format
@@ -329,15 +321,19 @@ class ContextCachingEndpoints(VertexBase):

         Follows - https://ai.google.dev/api/caching#request-body
         """
-        if cached_content is not None:
-            return messages, cached_content

-        cached_messages, non_cached_messages = separate_cached_messages(
-            messages=messages
+        cache_split_result = extract_cache_configuration(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
         )
-        if len(cached_messages) == 0:
-            return messages, None
+
+        if (
+            cache_split_result.cache_request_body is None
+            or cache_split_result.cached_content is not None
+            or cache_split_result.cache_key is None
+        ):
+            return cache_split_result

         ## AUTHORIZATION ##
         token, url = self._get_token_and_url_context_caching(
@@ -362,9 +358,8 @@ class ContextCachingEndpoints(VertexBase):
             client = client

         ## CHECK IF CACHED ALREADY
-        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
         google_cache_name = await self.async_check_cache(
-            cache_key=generated_cache_key,
+            cache_key=cache_split_result.cache_key,
             client=client,
             headers=headers,
             api_key=api_key,
@@ -372,21 +367,16 @@ class ContextCachingEndpoints(VertexBase):
             logging_obj=logging_obj,
         )
         if google_cache_name:
-            return non_cached_messages, google_cache_name
+            return cache_split_result.with_cached_content(cached_content=google_cache_name)

         ## TRANSFORM REQUEST
-        cached_content_request_body = (
-            transform_openai_messages_to_gemini_context_caching(
-                model=model, messages=cached_messages, cache_key=generated_cache_key
-            )
-        )

         ## LOGGING
         logging_obj.pre_call(
             input=messages,
             api_key="",
             additional_args={
-                "complete_input_dict": cached_content_request_body,
+                "complete_input_dict": cache_split_result.cache_request_body,
                 "api_base": url,
                 "headers": headers,
             },
@@ -394,7 +384,7 @@ class ContextCachingEndpoints(VertexBase):
         try:
             response = await client.post(
-                url=url, headers=headers, json=cached_content_request_body  # type: ignore
+                url=url, headers=headers, json=cache_split_result.cache_request_body  # type: ignore
             )
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
@@ -407,7 +397,7 @@ class ContextCachingEndpoints(VertexBase):
         cached_content_response_obj = VertexAICachedContentResponseObject(
             name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
         )
-        return (non_cached_messages, cached_content_response_obj["name"])
+        return cache_split_result.with_cached_content(cached_content=cached_content_response_obj["name"])

     def get_cache(self):
         pass
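One consequence of returning a CacheSplitResult is that the "caller already holds a cache handle" case needs no HTTP round-trip at all: extract_cache_configuration passes the supplied cached_content straight through, and the guard at the top of check_and_create_cache / async_check_and_create_cache returns before any request is made. A small sketch of that pass-through (same assumed module path as above; the cachedContents name is illustrative):

from litellm.llms.vertex_ai.context_caching.transformation import (
    extract_cache_configuration,
)

split = extract_cache_configuration(
    model="gemini-2.0-flash-001",
    messages=[{"role": "user", "content": "Hello"}],
    optional_params={"cached_content": "cachedContents/fake_cache_id"},
)

# cached_content is propagated untouched and no creation body is built,
# so the endpoint's early-return guard fires before any network call.
assert split.cached_content == "cachedContents/fake_cache_id"
assert split.cache_request_body is None
assert split.remaining_messages == [{"role": "user", "content": "Hello"}]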

View file

@@ -1,5 +1,5 @@
 """
 Transformation logic from OpenAI format to Gemini format.

 Why separate file? Make it easy to see how transformation works
 """
@@ -431,17 +431,20 @@ def sync_transform_request_body(
     context_caching_endpoints = ContextCachingEndpoints()

     if gemini_api_key is not None:
-        messages, cached_content = context_caching_endpoints.check_and_create_cache(
+        cache_split_result = context_caching_endpoints.check_and_create_cache(
             messages=messages,
+            optional_params=optional_params,
             api_key=gemini_api_key,
             api_base=api_base,
             model=model,
             client=client,
             timeout=timeout,
             extra_headers=extra_headers,
-            cached_content=optional_params.pop("cached_content", None),
             logging_obj=logging_obj,
         )
+        messages = cache_split_result.remaining_messages
+        cached_content = cache_split_result.cached_content
+        optional_params = cache_split_result.optional_params
     else:  # [TODO] implement context caching for gemini as well
         cached_content = optional_params.pop("cached_content", None)
@@ -473,20 +476,20 @@ async def async_transform_request_body(
     context_caching_endpoints = ContextCachingEndpoints()

     if gemini_api_key is not None:
-        (
-            messages,
-            cached_content,
-        ) = await context_caching_endpoints.async_check_and_create_cache(
+        cache_split_result = await context_caching_endpoints.async_check_and_create_cache(
             messages=messages,
+            optional_params=optional_params,
             api_key=gemini_api_key,
             api_base=api_base,
             model=model,
             client=client,
             timeout=timeout,
             extra_headers=extra_headers,
-            cached_content=optional_params.pop("cached_content", None),
             logging_obj=logging_obj,
         )
+        messages = cache_split_result.remaining_messages
+        cached_content = cache_split_result.cached_content
+        optional_params = cache_split_result.optional_params
     else:  # [TODO] implement context caching for gemini as well
         cached_content = optional_params.pop("cached_content", None)

View file

@@ -251,8 +251,8 @@ class RequestBody(TypedDict, total=False):

 class CachedContentRequestBody(TypedDict, total=False):
     contents: Required[List[ContentType]]
     system_instruction: SystemInstructions
-    tools: Tools
-    toolConfig: ToolConfig
+    tools: Optional[List[Tools]]
+    toolConfig: Optional[ToolConfig]
     model: Required[str]  # Format: models/{model}
     ttl: str  # ending in 's' - Example: "3.5s".
     displayName: str
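Because CachedContentRequestBody is a total=False TypedDict, the new tools and toolConfig keys stay optional, and a cache created without tools simply serializes them as None (as the first test below asserts). A hedged sketch of building one by hand, with illustrative values:

from litellm.types.llms.vertex_ai import CachedContentRequestBody

body: CachedContentRequestBody = {
    "contents": [],
    "model": "models/gemini-2.0-flash-001",
    "displayName": "some-cache-key",  # litellm uses a hash of the cached messages here
    "tools": [
        {
            "functionDeclarations": [
                {"name": "get_weather", "description": "Get the current weather"}
            ]
        }
    ],
    "toolConfig": {"functionCallingConfig": {"mode": "ANY"}},
}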

View file

@@ -1,14 +1,21 @@
 import asyncio
 from typing import List, cast
+import json
 from unittest.mock import MagicMock

 import pytest
+import respx
 from pydantic import BaseModel

 import litellm
 from litellm import ModelResponse
+from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.llms.gemini.chat.transformation import GoogleAIStudioGeminiConfig
+from litellm.llms.vertex_ai.gemini.transformation import async_transform_request_body, sync_transform_request_body
 from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
+    VertexLLM,
 )
 from litellm.types.utils import ChoiceLogprobs
@@ -239,3 +246,211 @@ def test_vertex_ai_thinking_output_part():
     content, reasoning_content = v.get_assistant_content_message(parts=parts)
     assert content == "Hello world"
     assert reasoning_content == "I'm thinking..."
+
+
+def _mock_logging():
+    mock_logging = MagicMock(spec=LiteLLMLoggingObj)
+    mock_logging.pre_call = MagicMock(return_value=None)
+    mock_logging.post_call = MagicMock(return_value=None)
+    return mock_logging
+
+
+def _mock_get_post_cached_content(
+    api_key: str, respx_mock: respx.MockRouter
+) -> tuple[respx.MockRouter, respx.MockRouter]:
+    get_mock = respx_mock.get(
+        f"https://generativelanguage.googleapis.com/v1beta/cachedContents?key={api_key}"
+    ).respond(
+        json={
+            "cachedContents": [],
+            "nextPageToken": None,
+        }
+    )
+    post_mock = respx_mock.post(
+        f"https://generativelanguage.googleapis.com/v1beta/cachedContents?key={api_key}"
+    ).respond(
+        json={
+            "name": "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id",
+            "model": "gemini-2.0-flash-001",
+        }
+    )
+    return get_mock, post_mock
+
+
+def test_google_ai_studio_gemini_message_caching_sync(
+    # ideally this would unit test just a small transformation, but there's a lot going on with gemini/vertex
+    # (hinges around branching for sync/async transformations).
+    respx_mock: respx.MockRouter,
+):
+    mock_logging = _mock_logging()
+    get_mock, post_mock = _mock_get_post_cached_content("fake_api_key", respx_mock)
+    transformed_request = sync_transform_request_body(
+        gemini_api_key="fake_api_key",
+        messages=[
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "you are a helpful assistant",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "content": "Hello, world!",
+            },
+        ],
+        api_base=None,
+        model="gemini-2.0-flash-001",
+        client=None,
+        timeout=None,
+        extra_headers=None,
+        optional_params={},
+        logging_obj=mock_logging,
+        custom_llm_provider="vertex_ai",
+        litellm_params={},
+    )
+    # Assert both GET and POST endpoints were called
+    assert get_mock.calls.call_count == 1
+    assert post_mock.calls.call_count == 1
+    assert json.loads(post_mock.calls[0].request.content) == {
+        "contents": [],
+        "model": "models/gemini-2.0-flash-001",
+        "displayName": "203ae753b6c793e1af13b13d0710de5863c486e610963ce243b07ee6830ce1d2",
+        "tools": None,
+        "toolConfig": None,
+        "system_instruction": {"parts": [{"text": "you are a helpful assistant"}]},
+    }
+    assert transformed_request["contents"] == [
+        {"parts": [{"text": "Hello, world!"}], "role": "user"}
+    ]
+    assert (
+        transformed_request["cachedContent"]
+        == "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id"
+    )
+
+
+_GET_WEATHER_MESSAGES = [
+    {
+        "role": "system",
+        "content": [
+            {
+                "type": "text",
+                "text": "you are a helpful assistant",
+                "cache_control": {"type": "ephemeral"},
+            }
+        ],
+    },
+    {
+        "role": "user",
+        "content": "What is the weather now?",
+    },
+]
+
+_GET_WEATHER_TOOLS_OPTIONAL_PARAMS = {
+    "tools": [
+        {
+            "functionDeclarations": [
+                {"name": "get_weather", "description": "Get the current weather"}
+            ],
+        }
+    ],
+    "tool_choice": {"functionCallingConfig": {"mode": "ANY"}},
+}
+
+_EXPECTED_GET_WEATHER_CACHED_CONTENT_REQUEST_BODY = {
+    "contents": [],
+    "model": "models/gemini-2.0-flash-001",
+    "displayName": "62398619ff33908a18561c1a342c580c3d876f169d103ec52128df38f04e03d1",
+    "tools": [
+        {
+            "functionDeclarations": [
+                {"name": "get_weather", "description": "Get the current weather"}
+            ],
+        }
+    ],
+    "toolConfig": {"functionCallingConfig": {"mode": "ANY"}},
+    "system_instruction": {"parts": [{"text": "you are a helpful assistant"}]},
+}
+
+
+def test_google_ai_studio_gemini_message_caching_with_tools_sync(
+    respx_mock: respx.MockRouter,
+):
+    mock_logging = _mock_logging()
+    get_mock, post_mock = _mock_get_post_cached_content("fake_api_key", respx_mock)
+    transformed_request = sync_transform_request_body(
+        gemini_api_key="fake_api_key",
+        messages=_GET_WEATHER_MESSAGES,
+        api_base=None,
+        model="gemini-2.0-flash-001",
+        client=None,
+        timeout=None,
+        extra_headers=None,
+        optional_params=_GET_WEATHER_TOOLS_OPTIONAL_PARAMS,
+        logging_obj=mock_logging,
+        custom_llm_provider="vertex_ai",
+        litellm_params={},
+    )
+    # Assert both GET and POST endpoints were called
+    assert get_mock.calls.call_count == 1
+    assert post_mock.calls.call_count == 1
+    assert (
+        json.loads(post_mock.calls[0].request.content)
+        == _EXPECTED_GET_WEATHER_CACHED_CONTENT_REQUEST_BODY
+    )
+    assert transformed_request["contents"] == [
+        {"parts": [{"text": "What is the weather now?"}], "role": "user"}
+    ]
+    assert (
+        transformed_request["cachedContent"]
+        == "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id"
+    )
+    assert transformed_request.get("tools") is None
+    assert transformed_request.get("tool_choice") is None
+
+
+@pytest.mark.asyncio
+async def test_google_ai_studio_gemini_message_caching_with_tools_async(
+    respx_mock: respx.MockRouter,
+):
+    mock_logging = _mock_logging()
+    get_mock, post_mock = _mock_get_post_cached_content("fake_api_key", respx_mock)
+    transformed_request = await async_transform_request_body(
+        gemini_api_key="fake_api_key",
+        messages=_GET_WEATHER_MESSAGES,
+        api_base=None,
+        model="gemini-2.0-flash-001",
+        client=None,
+        timeout=None,
+        extra_headers=None,
+        optional_params=_GET_WEATHER_TOOLS_OPTIONAL_PARAMS,
+        logging_obj=mock_logging,
+        custom_llm_provider="vertex_ai",
+        litellm_params={},
+    )
+    # Assert both GET and POST endpoints were called
+    assert get_mock.calls.call_count == 1
+    assert post_mock.calls.call_count == 1
+    assert (
+        json.loads(post_mock.calls[0].request.content)
+        == _EXPECTED_GET_WEATHER_CACHED_CONTENT_REQUEST_BODY
+    )
+    assert transformed_request["contents"] == [
+        {"parts": [{"text": "What is the weather now?"}], "role": "user"}
+    ]
+    assert (
+        transformed_request["cachedContent"]
+        == "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id"
+    )
+    assert transformed_request.get("tools") is None
+    assert transformed_request.get("tool_choice") is None
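The two displayName hashes asserted above differ only because tools and tool_choice now feed into the local cache key, so a tools-bearing request never reuses a cache entry that was created without tools. A rough sketch of that property, assuming Cache.get_cache_key folds arbitrary keyword arguments into the key the same way extract_cache_configuration uses it:

from litellm.caching.caching import Cache
from litellm.types.caching import LiteLLMCacheType

cache = Cache(type=LiteLLMCacheType.LOCAL)
msgs = [{"role": "system", "content": "you are a helpful assistant"}]

key_without_tools = cache.get_cache_key(messages=msgs)
key_with_tools = cache.get_cache_key(
    messages=msgs,
    tools=[{"functionDeclarations": [{"name": "get_weather"}]}],
)

# Expected to differ, mirroring the two displayName values in the tests above.
assert key_without_tools != key_with_tools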