mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
Merge 5ff8b85402 into b82af5b826
commit 4238a2505d
5 changed files with 397 additions and 64 deletions
@@ -1,13 +1,17 @@
 """
 Transformation logic for context caching.
 
 Why separate file? Make it easy to see how transformation works
 """
 
-from typing import List, Tuple
+from dataclasses import dataclass, replace
+from typing import Any, Dict, List, Optional, Tuple
 
+from litellm.caching.caching import Cache
+
+from litellm.types.caching import LiteLLMCacheType
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.llms.vertex_ai import CachedContentRequestBody
+from litellm.types.llms.vertex_ai import CachedContentRequestBody, ToolConfig, Tools
 from litellm.utils import is_cached_message
 
 from ..common_utils import get_supports_system_message
@@ -88,7 +92,11 @@ def separate_cached_messages(
 
 
 def transform_openai_messages_to_gemini_context_caching(
-    model: str, messages: List[AllMessageValues], cache_key: str
+    model: str,
+    messages: List[AllMessageValues],
+    cache_key: str,
+    tools: Optional[List[Tools]] = None,
+    tool_choice: Optional[ToolConfig] = None,
 ) -> CachedContentRequestBody:
     supports_system_message = get_supports_system_message(
         model=model, custom_llm_provider="gemini"
@@ -103,8 +111,126 @@ def transform_openai_messages_to_gemini_context_caching(
         contents=transformed_messages,
         model="models/{}".format(model),
         displayName=cache_key,
+        tools=tools,
+        toolConfig=tool_choice,
     )
     if transformed_system_messages is not None:
         data["system_instruction"] = transformed_system_messages
 
     return data
+
+
+local_cache_obj = Cache(type=LiteLLMCacheType.LOCAL)
+
+
+@dataclass(frozen=True)
+class CacheSplitResult:
+    """Result of splitting messages into cacheable and non-cacheable parts"""
+
+    remaining_messages: List[
+        AllMessageValues
+    ]  # Messages that should be sent in actual request
+    optional_params: Dict[str, Any]  # Updated params to be sent in actual request
+    cache_key: Optional[str]  # Key to use for checking if content is already cached
+    cached_content: Optional[
+        str
+    ]  # cached content ID, no further processing is needed once this is defined
+    cache_request_body: Optional[
+        CachedContentRequestBody
+    ]  # Request body to create new cache if needed
+
+    def with_cached_content(self, cached_content: str) -> "CacheSplitResult":
+        """
+        Returns an updated CacheSplitResult with the cached content applied.
+        """
+        updated_params = {**self.optional_params, "cached_content": cached_content}
+        return replace(
+            self,
+            cached_content=cached_content,
+            optional_params=updated_params,
+            cache_request_body=None,
+        )
+
+
+def extract_cache_configuration(
+    model: str,
+    messages: List[AllMessageValues],
+    optional_params: Dict[str, Any],
+) -> CacheSplitResult:
+    """
+    Checks if a given request should have a cache, and if so, extracts the cache
+    configuration, returning a modified version of the messages and optional params.
+
+    - Removes the cached content from the messages
+    - Adds the cache key to the optional params
+    - If there is cacheable content and tools / tool_choice are set, also moves them out of
+      the optional params and into the cache request body, since Google requires them to be
+      set on the CachedContent rather than on the GenerateContent request. (On Google's side
+      the tools are folded into a kind of system prompt.)
+
+    Relevant error:
+        "error": {
+            "code": 400,
+            "message": "CachedContent can not be used with GenerateContent request setting system_instruction, tools or tool_config.\n\nProposed fix: move those values to CachedContent from GenerateContent request.",
+            "status": "INVALID_ARGUMENT"
+        }
+
+    Returns:
+        CacheSplitResult with:
+        - remaining_messages: Messages that should be sent in the actual request
+        - cache_key: The key to use for checking if content is already cached
+        - cached_content: The cached content ID if already provided
+        - cache_request_body: The request body to create a new cache entry if needed
+    """
+    # If cached content is already provided, no need to process messages
+    if (
+        "cached_content" in optional_params
+        and optional_params["cached_content"] is not None
+    ):
+        return CacheSplitResult(
+            remaining_messages=messages,
+            optional_params=optional_params,
+            cache_key=None,
+            cached_content=optional_params["cached_content"],
+            cache_request_body=None,
+        )
+
+    # Separate messages that can be cached from those that can't
+    cached_messages, non_cached_messages = separate_cached_messages(messages=messages)
+
+    # If no messages can be cached, return original messages
+    if len(cached_messages) == 0:
+        return CacheSplitResult(
+            remaining_messages=messages,
+            optional_params=optional_params,
+            cache_key=None,
+            cached_content=None,
+            cache_request_body=None,
+        )
+
+    if "tools" in optional_params or "tool_choice" in optional_params:
+        optional_params = optional_params.copy()
+        tools = optional_params.pop("tools", None)
+        tool_choice = optional_params.pop("tool_choice", None)
+    else:
+        tools = None
+        tool_choice = None
+
+    key_kwargs = {}
+    if tools is not None:
+        key_kwargs["tools"] = tools
+    if tool_choice is not None:
+        key_kwargs["tool_choice"] = tool_choice
+
+    # Generate cache key for the cacheable messages
+    cache_key = local_cache_obj.get_cache_key(messages=cached_messages, **key_kwargs)
+
+    # Transform cached messages into request body
+    cache_request_body = transform_openai_messages_to_gemini_context_caching(
+        model=model, messages=cached_messages, cache_key=cache_key, tools=tools, tool_choice=tool_choice
+    )
+
+    return CacheSplitResult(
+        remaining_messages=non_cached_messages,
+        optional_params=optional_params,
+        cache_key=cache_key,
+        cached_content=None,
+        cache_request_body=cache_request_body,
+    )
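An aside on how the pieces above fit together: the sketch below shows how a caller would branch on the CacheSplitResult returned by extract_cache_configuration. It is an illustrative example written for this review, not code from the commit; it assumes it runs in the same module as the code above (so no imports are needed), the cache ID string is a placeholder, and the tools payload is copied from the tests added later in this diff.

# Illustrative only: consuming a CacheSplitResult.
split = extract_cache_configuration(
    model="gemini-2.0-flash-001",
    messages=[
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "you are a helpful assistant",
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        {"role": "user", "content": "Hello, world!"},
    ],
    optional_params={
        "tools": [
            {
                "functionDeclarations": [
                    {"name": "get_weather", "description": "Get the current weather"}
                ]
            }
        ],
        "tool_choice": {"functionCallingConfig": {"mode": "ANY"}},
    },
)

if split.cached_content is not None:
    # A cache ID was already passed in via optional_params["cached_content"];
    # split.remaining_messages / split.optional_params are ready to send.
    pass
elif split.cache_request_body is not None:
    # Nothing cached yet: POST split.cache_request_body to the cachedContents
    # endpoint, then fold the returned ID back in. The ID below is a placeholder.
    split = split.with_cached_content(cached_content="cachedContents/example-id")
# At this point split.optional_params carries "cached_content", and tools /
# tool_choice have been moved into the cache request body, which is what the
# INVALID_ARGUMENT error quoted in the docstring requires.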
@@ -1,9 +1,8 @@
-from typing import List, Literal, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union, Dict, Any
 
 import httpx
 
 import litellm
-from litellm.caching.caching import Cache, LiteLLMCacheType
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
@@ -19,15 +18,10 @@ from litellm.types.llms.vertex_ai import (
 from ..common_utils import VertexAIError
 from ..vertex_llm_base import VertexBase
 from .transformation import (
-    separate_cached_messages,
-    transform_openai_messages_to_gemini_context_caching,
+    CacheSplitResult,
+    extract_cache_configuration,
 )
 
-local_cache_obj = Cache(
-    type=LiteLLMCacheType.LOCAL
-)  # only used for calling 'get_cache_key' function
-
 
 class ContextCachingEndpoints(VertexBase):
     """
     Covers context caching endpoints for Vertex AI + Google AI Studio
@@ -205,6 +199,7 @@ class ContextCachingEndpoints(VertexBase):
     def check_and_create_cache(
         self,
         messages: List[AllMessageValues],  # receives openai format messages
+        optional_params: Dict[str, Any],
         api_key: str,
         api_base: Optional[str],
         model: str,
@@ -212,8 +207,7 @@
         timeout: Optional[Union[float, httpx.Timeout]],
         logging_obj: Logging,
         extra_headers: Optional[dict] = None,
-        cached_content: Optional[str] = None,
-    ) -> Tuple[List[AllMessageValues], Optional[str]]:
+    ) -> CacheSplitResult:
         """
         Receives
         - messages: List of dict - messages in the openai format
@@ -224,8 +218,19 @@
 
         Follows - https://ai.google.dev/api/caching#request-body
         """
-        if cached_content is not None:
-            return messages, cached_content
+        cache_split_result = extract_cache_configuration(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+        )
+
+        if (
+            cache_split_result.cache_request_body is None
+            or cache_split_result.cached_content is not None
+            or cache_split_result.cache_key is None
+        ):
+            return cache_split_result
 
         ## AUTHORIZATION ##
         token, url = self._get_token_and_url_context_caching(
@@ -252,17 +257,9 @@
         else:
             client = client
 
-        cached_messages, non_cached_messages = separate_cached_messages(
-            messages=messages
-        )
-
-        if len(cached_messages) == 0:
-            return messages, None
-
         ## CHECK IF CACHED ALREADY
-        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
         google_cache_name = self.check_cache(
-            cache_key=generated_cache_key,
+            cache_key=cache_split_result.cache_key,
             client=client,
             headers=headers,
             api_key=api_key,
@@ -270,21 +267,16 @@
             logging_obj=logging_obj,
         )
         if google_cache_name:
-            return non_cached_messages, google_cache_name
+            return cache_split_result.with_cached_content(cached_content=google_cache_name)
 
         ## TRANSFORM REQUEST
-        cached_content_request_body = (
-            transform_openai_messages_to_gemini_context_caching(
-                model=model, messages=cached_messages, cache_key=generated_cache_key
-            )
-        )
 
         ## LOGGING
         logging_obj.pre_call(
             input=messages,
             api_key="",
             additional_args={
-                "complete_input_dict": cached_content_request_body,
+                "complete_input_dict": cache_split_result.cache_request_body,
                 "api_base": url,
                 "headers": headers,
             },
@@ -292,7 +284,7 @@
 
         try:
             response = client.post(
-                url=url, headers=headers, json=cached_content_request_body  # type: ignore
+                url=url, headers=headers, json=cache_split_result.cache_request_body  # type: ignore
             )
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
@@ -305,11 +297,12 @@
         cached_content_response_obj = VertexAICachedContentResponseObject(
             name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
         )
-        return (non_cached_messages, cached_content_response_obj["name"])
+        return cache_split_result.with_cached_content(cached_content=cached_content_response_obj["name"])
 
     async def async_check_and_create_cache(
         self,
         messages: List[AllMessageValues],  # receives openai format messages
+        optional_params: Dict[str, Any],
         api_key: str,
         api_base: Optional[str],
         model: str,
@@ -317,8 +310,7 @@
         timeout: Optional[Union[float, httpx.Timeout]],
         logging_obj: Logging,
         extra_headers: Optional[dict] = None,
-        cached_content: Optional[str] = None,
-    ) -> Tuple[List[AllMessageValues], Optional[str]]:
+    ) -> CacheSplitResult:
         """
         Receives
         - messages: List of dict - messages in the openai format
@@ -329,15 +321,19 @@
 
         Follows - https://ai.google.dev/api/caching#request-body
         """
-        if cached_content is not None:
-            return messages, cached_content
-
-        cached_messages, non_cached_messages = separate_cached_messages(
-            messages=messages
+        cache_split_result = extract_cache_configuration(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
         )
 
-        if len(cached_messages) == 0:
-            return messages, None
+        if (
+            cache_split_result.cache_request_body is None
+            or cache_split_result.cached_content is not None
+            or cache_split_result.cache_key is None
+        ):
+            return cache_split_result
 
         ## AUTHORIZATION ##
         token, url = self._get_token_and_url_context_caching(
@@ -362,9 +358,8 @@
             client = client
 
         ## CHECK IF CACHED ALREADY
-        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
        google_cache_name = await self.async_check_cache(
-            cache_key=generated_cache_key,
+            cache_key=cache_split_result.cache_key,
             client=client,
             headers=headers,
             api_key=api_key,
@@ -372,21 +367,16 @@
             logging_obj=logging_obj,
         )
         if google_cache_name:
-            return non_cached_messages, google_cache_name
+            return cache_split_result.with_cached_content(cached_content=google_cache_name)
 
         ## TRANSFORM REQUEST
-        cached_content_request_body = (
-            transform_openai_messages_to_gemini_context_caching(
-                model=model, messages=cached_messages, cache_key=generated_cache_key
-            )
-        )
 
         ## LOGGING
         logging_obj.pre_call(
             input=messages,
             api_key="",
             additional_args={
-                "complete_input_dict": cached_content_request_body,
+                "complete_input_dict": cache_split_result.cache_request_body,
                 "api_base": url,
                 "headers": headers,
             },
@@ -394,7 +384,7 @@
 
         try:
             response = await client.post(
-                url=url, headers=headers, json=cached_content_request_body  # type: ignore
+                url=url, headers=headers, json=cache_split_result.cache_request_body  # type: ignore
             )
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
@@ -407,7 +397,7 @@
         cached_content_response_obj = VertexAICachedContentResponseObject(
             name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
         )
-        return (non_cached_messages, cached_content_response_obj["name"])
+        return cache_split_result.with_cached_content(cached_content=cached_content_response_obj["name"])
 
     def get_cache(self):
         pass
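For orientation, the HTTP round trips behind check_and_create_cache / async_check_and_create_cache boil down to a list-then-create exchange with the Google cachedContents endpoint. The sketch below is a simplification written for this review: the URL matches the one mocked in the new tests further down, but matching existing entries by displayName is an assumption about what check_cache does, and the real code goes through the shared HTTP handlers, auth, and logging shown above rather than bare httpx calls.

import httpx


def check_and_create_cache_sketch(api_key: str, split):  # split: CacheSplitResult
    base = "https://generativelanguage.googleapis.com/v1beta/cachedContents"

    # 1) See whether a cache with this displayName (the locally computed
    #    cache_key) already exists; matching on displayName is an assumption.
    listing = httpx.get(f"{base}?key={api_key}").json()
    for entry in listing.get("cachedContents") or []:
        if entry.get("displayName") == split.cache_key:
            return split.with_cached_content(cached_content=entry["name"])

    # 2) Otherwise create it; the returned "name" is the cached-content ID that
    #    later travels in optional_params["cached_content"].
    response = httpx.post(f"{base}?key={api_key}", json=split.cache_request_body)
    response.raise_for_status()
    return split.with_cached_content(cached_content=response.json()["name"])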
@@ -1,5 +1,5 @@
 """
 Transformation logic from OpenAI format to Gemini format.
 
 Why separate file? Make it easy to see how transformation works
 """
@@ -384,17 +384,20 @@ def sync_transform_request_body(
     context_caching_endpoints = ContextCachingEndpoints()
 
     if gemini_api_key is not None:
-        messages, cached_content = context_caching_endpoints.check_and_create_cache(
+        cache_split_result = context_caching_endpoints.check_and_create_cache(
             messages=messages,
+            optional_params=optional_params,
             api_key=gemini_api_key,
             api_base=api_base,
             model=model,
             client=client,
             timeout=timeout,
             extra_headers=extra_headers,
-            cached_content=optional_params.pop("cached_content", None),
             logging_obj=logging_obj,
         )
+        messages = cache_split_result.remaining_messages
+        cached_content = cache_split_result.cached_content
+        optional_params = cache_split_result.optional_params
     else:  # [TODO] implement context caching for gemini as well
         cached_content = optional_params.pop("cached_content", None)
 
@@ -426,20 +429,20 @@ async def async_transform_request_body(
     context_caching_endpoints = ContextCachingEndpoints()
 
     if gemini_api_key is not None:
-        (
-            messages,
-            cached_content,
-        ) = await context_caching_endpoints.async_check_and_create_cache(
+        cache_split_result = await context_caching_endpoints.async_check_and_create_cache(
             messages=messages,
+            optional_params=optional_params,
             api_key=gemini_api_key,
             api_base=api_base,
             model=model,
             client=client,
             timeout=timeout,
             extra_headers=extra_headers,
-            cached_content=optional_params.pop("cached_content", None),
             logging_obj=logging_obj,
        )
+        messages = cache_split_result.remaining_messages
+        cached_content = cache_split_result.cached_content
+        optional_params = cache_split_result.optional_params
     else:  # [TODO] implement context caching for gemini as well
         cached_content = optional_params.pop("cached_content", None)
 
@@ -252,8 +252,8 @@ class RequestBody(TypedDict, total=False):
 class CachedContentRequestBody(TypedDict, total=False):
     contents: Required[List[ContentType]]
     system_instruction: SystemInstructions
-    tools: Tools
-    toolConfig: ToolConfig
+    tools: Optional[List[Tools]]
+    toolConfig: Optional[ToolConfig]
     model: Required[str]  # Format: models/{model}
     ttl: str  # ending in 's' - Example: "3.5s".
     displayName: str
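Concretely, a body matching this TypedDict is what gets POSTed to the cachedContents endpoint. The literal below mirrors the expected request body asserted in the new tests at the end of this diff (displayName is the locally computed cache key); the type annotation assumes the types defined above are in scope.

example_body: CachedContentRequestBody = {
    "contents": [],
    "model": "models/gemini-2.0-flash-001",
    "displayName": "62398619ff33908a18561c1a342c580c3d876f169d103ec52128df38f04e03d1",
    "tools": [
        {
            "functionDeclarations": [
                {"name": "get_weather", "description": "Get the current weather"}
            ]
        }
    ],
    "toolConfig": {"functionCallingConfig": {"mode": "ANY"}},
    "system_instruction": {"parts": [{"text": "you are a helpful assistant"}]},
}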
@@ -1,14 +1,21 @@
 import asyncio
-from typing import List, cast
+import json
+from unittest.mock import MagicMock
 
 import pytest
+import respx
 from pydantic import BaseModel
 
 import litellm
 from litellm import ModelResponse
 from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
 from litellm.llms.gemini.chat.transformation import GoogleAIStudioGeminiConfig
+from litellm.llms.vertex_ai.gemini.transformation import async_transform_request_body, sync_transform_request_body
 from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
     VertexLLM,
 )
 from litellm.types.llms.vertex_ai import UsageMetadata
 from litellm.types.utils import ChoiceLogprobs, Usage
@@ -242,6 +249,213 @@ def test_vertex_ai_thinking_output_part():
     assert reasoning_content == "I'm thinking..."
 
 
+def _mock_logging():
+    mock_logging = MagicMock(spec=LiteLLMLoggingObj)
+    mock_logging.pre_call = MagicMock(return_value=None)
+    mock_logging.post_call = MagicMock(return_value=None)
+    return mock_logging
+
+
+def _mock_get_post_cached_content(api_key: str, respx_mock: respx.MockRouter) -> tuple[respx.MockRouter, respx.MockRouter]:
+    get_mock = respx_mock.get(
+        f"https://generativelanguage.googleapis.com/v1beta/cachedContents?key={api_key}"
+    ).respond(
+        json={
+            "cachedContents": [],
+            "nextPageToken": None,
+        }
+    )
+
+    post_mock = respx_mock.post(
+        f"https://generativelanguage.googleapis.com/v1beta/cachedContents?key={api_key}"
+    ).respond(
+        json={
+            "name": "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id",
+            "model": "gemini-2.0-flash-001",
+        }
+    )
+    return get_mock, post_mock
+
+
+def test_google_ai_studio_gemini_message_caching_sync(
+    # ideally this would unit test just a small transformation, but there's a lot going on with gemini/vertex
+    # (hinges around branching for sync/async transformations).
+    respx_mock: respx.MockRouter,
+):
+    mock_logging = _mock_logging()
+
+    get_mock, post_mock = _mock_get_post_cached_content("fake_api_key", respx_mock)
+
+    transformed_request = sync_transform_request_body(
+        gemini_api_key="fake_api_key",
+        messages=[
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "you are a helpful assistant",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "content": "Hello, world!",
+            },
+        ],
+        api_base=None,
+        model="gemini-2.0-flash-001",
+        client=None,
+        timeout=None,
+        extra_headers=None,
+        optional_params={},
+        logging_obj=mock_logging,
+        custom_llm_provider="vertex_ai",
+        litellm_params={},
+    )
+    # Assert both GET and POST endpoints were called
+    assert get_mock.calls.call_count == 1
+    assert post_mock.calls.call_count == 1
+    assert json.loads(post_mock.calls[0].request.content) == {
+        "contents": [],
+        "model": "models/gemini-2.0-flash-001",
+        "displayName": "203ae753b6c793e1af13b13d0710de5863c486e610963ce243b07ee6830ce1d2",
+        "tools": None,
+        "toolConfig": None,
+        "system_instruction": {"parts": [{"text": "you are a helpful assistant"}]},
+    }
+
+    assert transformed_request["contents"] == [
+        {"parts": [{"text": "Hello, world!"}], "role": "user"}
+    ]
+    assert (
+        transformed_request["cachedContent"]
+        == "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id"
+    )
+
+
+_GET_WEATHER_MESSAGES = [
+    {
+        "role": "system",
+        "content": [
+            {
+                "type": "text",
+                "text": "you are a helpful assistant",
+                "cache_control": {"type": "ephemeral"},
+            }
+        ],
+    },
+    {
+        "role": "user",
+        "content": "What is the weather now?",
+    },
+]
+
+_GET_WEATHER_TOOLS_OPTIONAL_PARAMS = {
+    "tools": [
+        {
+            "functionDeclarations": [
+                {"name": "get_weather", "description": "Get the current weather"}
+            ],
+        }
+    ],
+    "tool_choice": {
+        "functionCallingConfig": {
+            "mode": "ANY"
+        }
+    },
+}
+
+_EXPECTED_GET_WEATHER_CACHED_CONTENT_REQUEST_BODY = {
+    "contents": [],
+    "model": "models/gemini-2.0-flash-001",
+    "displayName": "62398619ff33908a18561c1a342c580c3d876f169d103ec52128df38f04e03d1",
+    "tools": [
+        {
+            "functionDeclarations": [
+                {"name": "get_weather", "description": "Get the current weather"}
+            ],
+        }
+    ],
+    "toolConfig": {
+        "functionCallingConfig": {
+            "mode": "ANY"
+        }
+    },
+    "system_instruction": {"parts": [{"text": "you are a helpful assistant"}]},
+}
+
+
+def test_google_ai_studio_gemini_message_caching_with_tools_sync(
+    respx_mock: respx.MockRouter,
+):
+    mock_logging = _mock_logging()
+
+    get_mock, post_mock = _mock_get_post_cached_content("fake_api_key", respx_mock)
+
+    transformed_request = sync_transform_request_body(
+        gemini_api_key="fake_api_key",
+        messages=_GET_WEATHER_MESSAGES,
+        api_base=None,
+        model="gemini-2.0-flash-001",
+        client=None,
+        timeout=None,
+        extra_headers=None,
+        optional_params=_GET_WEATHER_TOOLS_OPTIONAL_PARAMS,
+        logging_obj=mock_logging,
+        custom_llm_provider="vertex_ai",
+        litellm_params={},
+    )
+    # Assert both GET and POST endpoints were called
+    assert get_mock.calls.call_count == 1
+    assert post_mock.calls.call_count == 1
+    assert json.loads(post_mock.calls[0].request.content) == _EXPECTED_GET_WEATHER_CACHED_CONTENT_REQUEST_BODY
+
+    assert transformed_request["contents"] == [
+        {"parts": [{"text": "What is the weather now?"}], "role": "user"}
+    ]
+    assert (
+        transformed_request["cachedContent"]
+        == "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id"
+    )
+    assert transformed_request.get("tools") is None
+    assert transformed_request.get("tool_choice") is None
+
+
+@pytest.mark.asyncio
+async def test_google_ai_studio_gemini_message_caching_with_tools_async(
+    respx_mock: respx.MockRouter,
+):
+    mock_logging = _mock_logging()
+
+    get_mock, post_mock = _mock_get_post_cached_content("fake_api_key", respx_mock)
+
+    transformed_request = await async_transform_request_body(
+        gemini_api_key="fake_api_key",
+        messages=_GET_WEATHER_MESSAGES,
+        api_base=None,
+        model="gemini-2.0-flash-001",
+        client=None,
+        timeout=None,
+        extra_headers=None,
+        optional_params=_GET_WEATHER_TOOLS_OPTIONAL_PARAMS,
+        logging_obj=mock_logging,
+        custom_llm_provider="vertex_ai",
+        litellm_params={},
+    )
+    # Assert both GET and POST endpoints were called
+    assert get_mock.calls.call_count == 1
+    assert post_mock.calls.call_count == 1
+    assert json.loads(post_mock.calls[0].request.content) == _EXPECTED_GET_WEATHER_CACHED_CONTENT_REQUEST_BODY
+
+    assert transformed_request["contents"] == [
+        {"parts": [{"text": "What is the weather now?"}], "role": "user"}
+    ]
+    assert (
+        transformed_request["cachedContent"]
+        == "projects/fake_project/locations/fake_location/cachedContents/fake_cache_id"
+    )
+    assert transformed_request.get("tools") is None
+    assert transformed_request.get("tool_choice") is None
+
+
 def test_vertex_ai_empty_content():
     from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
         VertexGeminiConfig,
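For context, the message shape these tests drive corresponds to a user-level call along the following lines. This is a minimal sketch written for this review, not part of the commit; it assumes a Google AI Studio key is configured for the gemini/ provider.

import litellm

# The system block marked with cache_control is split out and stored as Gemini
# cachedContent; only the user turn is sent in the live generation request.
response = litellm.completion(
    model="gemini/gemini-2.0-flash-001",
    messages=[
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "you are a helpful assistant",
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        {"role": "user", "content": "Hello, world!"},
    ],
)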