Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)

commit 074e30fa10 (parent 6ff17f1acd)
feat(vertex_ai_context_caching.py): support making context caching calls to vertex ai in a normal chat completion call (anthropic caching format)

Closes https://github.com/BerriAI/litellm/issues/5213

16 changed files with 594 additions and 90 deletions
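In practice, the feature lets a caller mark message blocks with the Anthropic-style {"cache_control": {"type": "ephemeral"}} flag and have them registered with Gemini's context cache transparently, inside a normal chat completion call. The snippet below is a minimal usage sketch based on the test added in this diff; the model name, message text, and token limit are illustrative, not prescriptive.

```python
import litellm

# Blocks marked with {"cache_control": {"type": "ephemeral"}} (the Anthropic caching
# format) are split out, registered via the Google AI Studio cachedContents endpoint,
# and the returned cache name is reused for the actual chat completion request.
response = litellm.completion(
    model="gemini/gemini-1.5-flash-001",
    messages=[
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "Here is the full text of a complex legal agreement" * 4000,
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        {"role": "user", "content": "What are the key terms and conditions in this agreement?"},
    ],
    max_tokens=10,
)
print(response.choices[0].message.content)
```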
@@ -848,15 +848,21 @@ from .llms.gemini import GeminiConfig
 from .llms.nlp_cloud import NLPCloudConfig
 from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
-from .llms.vertex_httpx import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
     GoogleAIStudioGeminiConfig,
     VertexAIConfig,
 )
-from .llms.vertex_ai import VertexAITextEmbeddingConfig
-from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
-from .llms.vertex_ai_partner import VertexAILlama3Config
-from .llms.sagemaker.sagemaker import SagemakerConfig
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+    VertexAITextEmbeddingConfig,
+)
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
+    VertexAIAnthropicConfig,
+)
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+    VertexAILlama3Config,
+)
+from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
@@ -8,7 +8,9 @@ from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparamet
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.llms.vertex_httpx import VertexLLM
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    VertexLLM,
+)
 from litellm.types.llms.openai import FineTuningJobCreate
 from litellm.types.llms.vertex_ai import (
     FineTuneJobCreate,
@@ -13,7 +13,9 @@ from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
 )
 from litellm.llms.openai import HttpxBinaryResponseContent
-from litellm.llms.vertex_httpx import VertexLLM
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    VertexLLM,
+)
 
 
 class VertexInput(TypedDict, total=False):
litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from typing import Literal
+
+import httpx
+
+from litellm import supports_system_messages, verbose_logger
+
+
+class VertexAIError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url=" https://cloud.google.com/vertex-ai/"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+def get_supports_system_message(
+    model: str, custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"]
+) -> bool:
+    try:
+        _custom_llm_provider = custom_llm_provider
+        if custom_llm_provider == "vertex_ai_beta":
+            _custom_llm_provider = "vertex_ai"
+        supports_system_message = supports_system_messages(
+            model=model, custom_llm_provider=_custom_llm_provider
+        )
+    except Exception as e:
+        verbose_logger.warning(
+            "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+                str(e)
+            )
+        )
+        supports_system_message = False
+
+    return supports_system_message
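As a quick illustration of the helper above (a hypothetical call, assuming the model is listed in model_prices_and_context_window.json):

```python
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
    get_supports_system_message,
)

# "vertex_ai_beta" is normalized to "vertex_ai" before the capability lookup;
# lookup failures fall back to False instead of raising.
supported = get_supports_system_message(
    model="gemini-1.5-flash-001", custom_llm_provider="gemini"
)
print(supported)  # True when the model's config marks system messages as supported
```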
@@ -0,0 +1,88 @@
+"""
+Transformation logic for context caching.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+from typing import List, Tuple
+
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import CachedContentRequestBody, SystemInstructions
+from litellm.utils import is_cached_message
+
+from ..common_utils import VertexAIError, get_supports_system_message
+from ..gemini_transformation import transform_system_message
+from ..vertex_and_google_ai_studio_gemini import _gemini_convert_messages_with_history
+
+
+def separate_cached_messages(
+    messages: List[AllMessageValues],
+) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
+    """
+    Returns separated cached and non-cached messages.
+
+    Args:
+        messages: List of messages to be separated.
+
+    Returns:
+        Tuple containing:
+        - cached_messages: List of cached messages.
+        - non_cached_messages: List of non-cached messages.
+    """
+    cached_messages: List[AllMessageValues] = []
+    non_cached_messages: List[AllMessageValues] = []
+
+    # Extract cached messages and their indices
+    filtered_messages: List[Tuple[int, AllMessageValues]] = []
+    for idx, message in enumerate(messages):
+        if is_cached_message(message=message):
+            filtered_messages.append((idx, message))
+
+    # Validate only one block of continuous cached messages
+    if len(filtered_messages) > 1:
+        expected_idx = filtered_messages[0][0] + 1
+        for idx, _ in filtered_messages[1:]:
+            if idx != expected_idx:
+                raise VertexAIError(
+                    status_code=422,
+                    message="Gemini Context Caching only supports 1 message/block of continuous messages. Your idx, messages were - {}".format(
+                        filtered_messages
+                    ),
+                )
+            expected_idx += 1
+
+    # Separate messages based on the block of cached messages
+    if filtered_messages:
+        first_cached_idx = filtered_messages[0][0]
+        last_cached_idx = filtered_messages[-1][0]
+
+        cached_messages = messages[first_cached_idx : last_cached_idx + 1]
+        non_cached_messages = (
+            messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
+        )
+    else:
+        non_cached_messages = messages
+
+    return cached_messages, non_cached_messages
+
+
+def transform_openai_messages_to_gemini_context_caching(
+    model: str,
+    messages: List[AllMessageValues],
+) -> CachedContentRequestBody:
+    supports_system_message = get_supports_system_message(
+        model=model, custom_llm_provider="gemini"
+    )
+
+    transformed_system_messages, new_messages = transform_system_message(
+        supports_system_message=supports_system_message, messages=messages
+    )
+
+    transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
+    data = CachedContentRequestBody(
+        contents=transformed_messages, model="models/{}".format(model)
+    )
+    if transformed_system_messages is not None:
+        data["system_instruction"] = transformed_system_messages
+
+    return data
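A rough usage sketch of the two helpers above (message contents are illustrative; the import path assumes this module lands at litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py, which is what the relative imports suggest):

```python
from litellm.llms.vertex_ai_and_google_ai_studio.context_caching.transformation import (
    separate_cached_messages,
    transform_openai_messages_to_gemini_context_caching,
)

messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Long shared context ...",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {"role": "user", "content": "What are the key terms?"},
]

# The contiguous cache_control block is split off; everything else is sent normally.
cached, non_cached = separate_cached_messages(messages=messages)

# The cached block becomes a cachedContents request body for "models/{model}".
request_body = transform_openai_messages_to_gemini_context_caching(
    model="gemini-1.5-flash-001", messages=cached
)
```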
@@ -0,0 +1,170 @@
+import types
+from typing import Callable, List, Literal, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import (
+    RequestBody,
+    VertexAICachedContentResponseObject,
+)
+from litellm.utils import ModelResponse
+
+from ..common_utils import VertexAIError
+from .transformation import (
+    separate_cached_messages,
+    transform_openai_messages_to_gemini_context_caching,
+)
+
+
+class ContextCachingEndpoints:
+    """
+    Covers context caching endpoints for Vertex AI + Google AI Studio
+
+    v0: covers Google AI Studio
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def _get_token_and_url(
+        self,
+        model: str,
+        gemini_api_key: Optional[str],
+        custom_llm_provider: Literal["gemini"],
+        api_base: Optional[str],
+    ) -> Tuple[Optional[str], str]:
+        """
+        Internal function. Returns the token and url for the call.
+
+        Handles logic if it's google ai studio vs. vertex ai.
+
+        Returns
+            token, url
+        """
+        if custom_llm_provider == "gemini":
+            _gemini_model_name = "models/{}".format(model)
+            auth_header = None
+            endpoint = "cachedContents"
+            url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
+                endpoint, gemini_api_key
+            )
+
+        else:
+            raise NotImplementedError
+        if (
+            api_base is not None
+        ):  # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
+            if custom_llm_provider == "gemini":
+                url = "{}/{}".format(api_base, endpoint)
+                auth_header = (
+                    gemini_api_key  # cloudflare expects api key as bearer token
+                )
+            else:
+                url = "{}:{}".format(api_base, endpoint)
+
+        return auth_header, url
+
+    def create_cache(
+        self,
+        messages: List[AllMessageValues],  # receives openai format messages
+        api_key: str,
+        api_base: Optional[str],
+        model: str,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        logging_obj: Logging,
+        extra_headers: Optional[dict] = None,
+        cached_content: Optional[str] = None,
+    ) -> Tuple[List[AllMessageValues], Optional[str]]:
+        """
+        Receives
+        - messages: List of dict - messages in the openai format
+
+        Returns
+        - messages - List[dict] - filtered list of messages in the openai format.
+        - cached_content - str - the cache content id, to be passed in the gemini request body
+
+        Follows - https://ai.google.dev/api/caching#request-body
+        """
+        if cached_content is not None:
+            return messages, cached_content
+
+        ## AUTHORIZATION ##
+        token, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=api_key,
+            custom_llm_provider="gemini",
+            api_base=api_base,
+        )
+
+        headers = {
+            "Content-Type": "application/json",
+        }
+        if token is not None:
+            headers["Authorization"] = f"Bearer {token}"
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
+        if client is None or not isinstance(client, HTTPHandler):
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = HTTPHandler(**_params)  # type: ignore
+        else:
+            client = client
+
+        cached_messages, non_cached_messages = separate_cached_messages(
+            messages=messages
+        )
+
+        if len(cached_messages) == 0:
+            return messages, None
+
+        cached_content_request_body = (
+            transform_openai_messages_to_gemini_context_caching(
+                model=model, messages=cached_messages
+            )
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": cached_content_request_body,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        try:
+            response = client.post(
+                url=url, headers=headers, json=cached_content_request_body  # type: ignore
+            )
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        raw_response_cached = response.json()
+        cached_content_response_obj = VertexAICachedContentResponseObject(
+            name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
+        )
+        return (non_cached_messages, cached_content_response_obj["name"])
+
+    def async_create_cache(self):
+        pass
+
+    def get_cache(self):
+        pass
+
+    async def async_get_cache(self):
+        pass
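Under the hood, create_cache amounts to one extra HTTP round trip before the normal chat call. The snippet below is a hedged sketch of the equivalent raw request, with the endpoint and payload shape taken from the code above; the API key, model, and contents are placeholders, and real callers should go through litellm rather than hitting the endpoint directly.

```python
import httpx

# POST the cache_control-marked block to the Google AI Studio cachedContents endpoint,
# then reuse the returned "name" as cached_content in the follow-up request.
api_key = "<GEMINI_API_KEY>"  # placeholder
url = f"https://generativelanguage.googleapis.com/v1beta/cachedContents?key={api_key}"
body = {
    "model": "models/gemini-1.5-flash-001",
    "contents": [{"role": "user", "parts": [{"text": "Long shared context ..."}]}],
}
resp = httpx.post(url, headers={"Content-Type": "application/json"}, json=body)
resp.raise_for_status()
cached_content_name = resp.json()["name"]  # e.g. "cachedContents/4d2kd477o3pg"
```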
@@ -0,0 +1,47 @@
+"""
+Transformation logic from OpenAI format to Gemini format.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+from typing import List, Optional, Tuple
+
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import PartType, SystemInstructions
+
+
+def transform_system_message(
+    supports_system_message: bool, messages: List[AllMessageValues]
+) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]:
+    """
+    Extracts the system message from the openai message list.
+
+    Converts the system message to Gemini format
+
+    Returns
+    - system_content_blocks: Optional[SystemInstructions] - the system message list in Gemini format.
+    - messages: List[AllMessageValues] - filtered list of messages in OpenAI format (transformed separately)
+    """
+    # Separate system prompt from rest of message
+    system_prompt_indices = []
+    system_content_blocks: List[PartType] = []
+    if supports_system_message is True:
+        for idx, message in enumerate(messages):
+            if message["role"] == "system":
+                if isinstance(message["content"], str):
+                    _system_content_block = PartType(text=message["content"])
+                elif isinstance(message["content"], list):
+                    system_text = ""
+                    for content in message["content"]:
+                        system_text += content.get("text") or ""
+                    _system_content_block = PartType(text=system_text)
+                system_content_blocks.append(_system_content_block)
+                system_prompt_indices.append(idx)
+        if len(system_prompt_indices) > 0:
+            for idx in reversed(system_prompt_indices):
+                messages.pop(idx)
+
+    if len(system_content_blocks) > 0:
+        return SystemInstructions(parts=system_content_blocks), messages
+
+    return None, messages
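A small sketch of the transformation above (inputs are illustrative):

```python
from litellm.llms.vertex_ai_and_google_ai_studio.gemini_transformation import (
    transform_system_message,
)

messages = [
    {"role": "system", "content": "You are a helpful legal assistant."},
    {"role": "user", "content": "Summarize the agreement."},
]

system_instruction, remaining = transform_system_message(
    supports_system_message=True, messages=messages
)
# system_instruction -> {"parts": [{"text": "You are a helpful legal assistant."}]}
# remaining          -> [{"role": "user", "content": "Summarize the agreement."}]
```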
@@ -26,7 +26,7 @@ from litellm.types.llms.openai import (
 from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
-from .prompt_templates.factory import (
+from ..prompt_templates.factory import (
     construct_tool_use_system_prompt,
     contains_tag,
     custom_prompt,
@@ -1,41 +1,14 @@
 # What is this?
 ## Handler for calling llama 3.1 API on Vertex AI
-import copy
-import json
-import os
-import time
 import types
-import uuid
-from enum import Enum
-from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+from typing import Callable, Literal, Optional, Union
 
 import httpx  # type: ignore
-import requests  # type: ignore
 
 import litellm
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import (
-    AnthropicMessagesTool,
-    AnthropicMessagesToolChoice,
-)
-from litellm.types.llms.openai import (
-    ChatCompletionToolParam,
-    ChatCompletionToolParamFunctionChunk,
-)
-from litellm.types.utils import ResponseFormatChunk
-from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+from litellm.utils import ModelResponse
 
-from .base import BaseLLM
-from .prompt_templates.factory import (
-    construct_tool_use_system_prompt,
-    contains_tag,
-    custom_prompt,
-    extract_between_tags,
-    parse_xml_params,
-    prompt_factory,
-    response_schema_prompt,
-)
+from ..base import BaseLLM
 
 
 class VertexAIError(Exception):
@@ -9,7 +9,7 @@ import types
 import uuid
 from enum import Enum
 from functools import partial
-from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
@@ -25,7 +25,9 @@ from litellm.llms.prompt_templates.factory import (
     convert_url_to_base64,
     response_schema_prompt,
 )
-from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+    _gemini_convert_messages_with_history,
+)
 from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
@@ -52,7 +54,12 @@ from litellm.types.llms.vertex_ai import (
 from litellm.types.utils import GenericStreamingChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
-from .base import BaseLLM
+from ..base import BaseLLM
+from .common_utils import VertexAIError, get_supports_system_message
+from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+from .gemini_transformation import transform_system_message
+
+context_caching_endpoints = ContextCachingEndpoints()
 
 
 class VertexAIConfig:
@@ -789,19 +796,6 @@ def make_sync_call(
     return completion_stream
 
 
-class VertexAIError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST", url=" https://cloud.google.com/vertex-ai/"
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
 class VertexLLM(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -1366,33 +1360,27 @@ class VertexLLM(BaseLLM):
         )
 
         ## TRANSFORMATION ##
-        try:
-            _custom_llm_provider = custom_llm_provider
-            if custom_llm_provider == "vertex_ai_beta":
-                _custom_llm_provider = "vertex_ai"
-            supports_system_message = litellm.supports_system_messages(
-                model=model, custom_llm_provider=_custom_llm_provider
+        ### CHECK CONTEXT CACHING ###
+        if gemini_api_key is not None:
+            messages, cached_content = context_caching_endpoints.create_cache(
+                messages=messages,
+                api_key=gemini_api_key,
+                api_base=api_base,
+                model=model,
+                client=client,
+                timeout=timeout,
+                extra_headers=extra_headers,
+                cached_content=optional_params.pop("cached_content", None),
+                logging_obj=logging_obj,
             )
-        except Exception as e:
-            verbose_logger.warning(
-                "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
-                    str(e)
-                )
-            )
-            supports_system_message = False
-        # Separate system prompt from rest of message
-        system_prompt_indices = []
-        system_content_blocks: List[PartType] = []
-        if supports_system_message is True:
-            for idx, message in enumerate(messages):
-                if message["role"] == "system":
-                    _system_content_block = PartType(text=message["content"])
-                    system_content_blocks.append(_system_content_block)
-                    system_prompt_indices.append(idx)
-            if len(system_prompt_indices) > 0:
-                for idx in reversed(system_prompt_indices):
-                    messages.pop(idx)
 
+        # Separate system prompt from rest of message
+        supports_system_message = get_supports_system_message(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        system_instructions, messages = transform_system_message(
+            supports_system_message=supports_system_message, messages=messages
+        )
         # Checks for 'response_schema' support - if passed in
         if "response_schema" in optional_params:
             supports_response_schema = litellm.supports_response_schema(
@@ -1426,13 +1414,11 @@ class VertexLLM(BaseLLM):
         safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
             "safety_settings", None
         )  # type: ignore
-        cached_content: Optional[str] = optional_params.pop("cached_content", None)
         generation_config: Optional[GenerationConfig] = GenerationConfig(
             **optional_params
         )
         data = RequestBody(contents=content)
-        if len(system_content_blocks) > 0:
-            system_instructions = SystemInstructions(parts=system_content_blocks)
+        if system_instructions is not None:
             data["system_instruction"] = system_instructions
         if tools is not None:
             data["tools"] = tools
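Put together, the request body assembled here now carries the transformed system instruction and, when the context-caching path ran, the cache name. The dict below only illustrates the resulting shape (field names come from the RequestBody typed dict and the test mock further down; tools and generation config are attached the same way as before):

```python
# Illustrative shape only - not produced verbatim by the code above.
data = {
    "contents": [
        {"role": "user", "parts": [{"text": "What are the key terms and conditions?"}]}
    ],
    "system_instruction": {"parts": [{"text": "You are a helpful legal assistant."}]},
    "cachedContent": "cachedContents/4d2kd477o3pg",
}
```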
@@ -95,8 +95,6 @@ from .llms import (
     replicate,
     together_ai,
     triton,
-    vertex_ai,
-    vertex_ai_anthropic,
     vllm,
     watsonx,
 )
@@ -124,8 +122,16 @@ from .llms.sagemaker.sagemaker import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.text_to_speech.vertex_ai import VertexTextToSpeechAPI
 from .llms.triton import TritonChatCompletion
-from .llms.vertex_ai_partner import VertexAIPartnerModels
-from .llms.vertex_httpx import VertexLLM
+from .llms.vertex_ai_and_google_ai_studio import (
+    vertex_ai_anthropic,
+    vertex_ai_non_gemini,
+)
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models import (
+    VertexAIPartnerModels,
+)
+from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    VertexLLM,
+)
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
 from .types.utils import (
@@ -2112,7 +2118,7 @@ def completion(
                 extra_headers=extra_headers,
             )
         else:
-            model_response = vertex_ai.completion(
+            model_response = vertex_ai_non_gemini.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
@@ -3558,7 +3564,7 @@ def embedding(
                 print_verbose=print_verbose,
             )
         else:
-            response = vertex_ai.embedding(
+            response = vertex_ai_non_gemini.embedding(
                 model=model,
                 input=input,
                 encoding=encoding,
@@ -28,7 +28,9 @@ from litellm import (
     completion_cost,
     embedding,
 )
-from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    _gemini_convert_messages_with_history,
+)
 from litellm.tests.test_streaming import streaming_format_tests
 
 litellm.num_retries = 3
@@ -2199,3 +2201,137 @@ async def test_completion_fine_tuned_model():
     # Optional: Print for debugging
     print("Arguments passed to Vertex AI:", args_to_vertexai)
     print("Response:", response)
+
+
+def mock_gemini_request(*args, **kwargs):
+    print(f"kwargs: {kwargs}")
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    if "cachedContents" in kwargs["url"]:
+        mock_response.json.return_value = {
+            "name": "cachedContents/4d2kd477o3pg",
+            "model": "models/gemini-1.5-flash-001",
+            "createTime": "2024-08-26T22:31:16.147190Z",
+            "updateTime": "2024-08-26T22:31:16.147190Z",
+            "expireTime": "2024-08-26T22:36:15.548934784Z",
+            "displayName": "",
+            "usageMetadata": {"totalTokenCount": 323383},
+        }
+    else:
+        mock_response.json.return_value = {
+            "candidates": [
+                {
+                    "content": {
+                        "parts": [
+                            {
+                                "text": "Please provide me with the text of the legal agreement"
+                            }
+                        ],
+                        "role": "model",
+                    },
+                    "finishReason": "MAX_TOKENS",
+                    "index": 0,
+                    "safetyRatings": [
+                        {
+                            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_HATE_SPEECH",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_HARASSMENT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                    ],
+                }
+            ],
+            "usageMetadata": {
+                "promptTokenCount": 40049,
+                "candidatesTokenCount": 10,
+                "totalTokenCount": 40059,
+                "cachedContentTokenCount": 40012,
+            },
+        }
+
+    return mock_response
+
+
+@pytest.mark.asyncio
+async def test_gemini_context_caching_anthropic_format():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    litellm.set_verbose = True
+    client = HTTPHandler(concurrent_limit=1)
+    with patch.object(client, "post", side_effect=mock_gemini_request) as mock_client:
+        try:
+            response = litellm.completion(
+                model="gemini/gemini-1.5-flash-001",
+                messages=[
+                    # System Message
+                    {
+                        "role": "system",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": "Here is the full text of a complex legal agreement"
+                                * 4000,
+                                "cache_control": {"type": "ephemeral"},
+                            }
+                        ],
+                    },
+                    # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+                    {
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": "What are the key terms and conditions in this agreement?",
+                                "cache_control": {"type": "ephemeral"},
+                            }
+                        ],
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo",
+                    },
+                    # The final turn is marked with cache-control, for continuing in followups.
+                    {
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": "What are the key terms and conditions in this agreement?",
+                            }
+                        ],
+                    },
+                ],
+                temperature=0.2,
+                max_tokens=10,
+                client=client,
+            )
+
+        except Exception as e:
+            print(e)
+
+        assert mock_client.call_count == 2
+
+        first_call_args = mock_client.call_args_list[0].kwargs
+
+        print(f"first_call_args: {first_call_args}")
+
+        assert "cachedContents" in first_call_args["url"]
+
+        # assert "cache_read_input_tokens" in response.usage
+        # assert "cache_creation_input_tokens" in response.usage
+
+        # # Assert either a cache entry was created or cache was read - changes depending on the anthropic api ttl
+        # assert (response.usage.cache_read_input_tokens > 0) or (
+        #     response.usage.cache_creation_input_tokens > 0
+        # )
@@ -325,11 +325,21 @@ class ChatCompletionDeltaToolCallChunk(TypedDict, total=False):
     index: int
 
 
-class ChatCompletionTextObject(TypedDict):
+class ChatCompletionCachedContent(TypedDict):
+    type: Literal["ephemeral"]
+
+
+class OpenAIChatCompletionTextObject(TypedDict):
     type: Literal["text"]
     text: str
 
 
+class ChatCompletionTextObject(
+    OpenAIChatCompletionTextObject, total=False
+):  # litellm wrapper on top of openai object for handling cached content
+    cache_control: ChatCompletionCachedContent
+
+
 class ChatCompletionImageUrlObject(TypedDict, total=False):
     url: Required[str]
     detail: str
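Concretely, a content block that satisfies the widened ChatCompletionTextObject type looks like this (values are illustrative):

```python
text_block = {
    "type": "text",
    "text": "Here is the full text of a complex legal agreement ...",
    # optional litellm extension on top of the OpenAI text-object shape
    "cache_control": {"type": "ephemeral"},
}
```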
@@ -186,6 +186,17 @@ class RequestBody(TypedDict, total=False):
     cachedContent: str
 
 
+class CachedContentRequestBody(TypedDict, total=False):
+    contents: Required[List[ContentType]]
+    system_instruction: SystemInstructions
+    tools: Tools
+    toolConfig: ToolConfig
+    model: Required[str]  # Format: models/{model}
+    ttl: str  # ending in 's' - Example: "3.5s".
+    name: str  # Format: cachedContents/{id}
+    displayName: str
+
+
 class SafetyRatings(TypedDict):
     category: HarmCategory
     probability: HarmProbability
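For reference, a minimal instance of the new typed dict might look like the following (only the required keys plus an optional ttl; the contents value is an illustrative Gemini-format turn):

```python
from litellm.types.llms.vertex_ai import CachedContentRequestBody

request_body = CachedContentRequestBody(
    model="models/gemini-1.5-flash-001",
    contents=[{"role": "user", "parts": [{"text": "Long shared context ..."}]}],
    ttl="300s",  # optional, duration string ending in 's'
)
```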
@@ -320,3 +331,8 @@ class Instance(TypedDict, total=False):
 
 class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
     instances: List[Instance]
+
+
+class VertexAICachedContentResponseObject(TypedDict):
+    name: str
+    model: str
@@ -69,6 +69,7 @@ from litellm.litellm_core_utils.redact_messages import (
 from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.llms.openai import (
+    AllMessageValues,
     ChatCompletionNamedToolChoiceParam,
     ChatCompletionToolParam,
 )
@@ -11549,3 +11550,25 @@ class ModelResponseListIterator:
 class CustomModelResponseIterator(Iterable):
     def __init__(self) -> None:
         super().__init__()
+
+
+def is_cached_message(message: AllMessageValues) -> bool:
+    """
+    Returns true, if message is marked as needing to be cached.
+
+    Used for anthropic/gemini context caching.
+
+    Follows the anthropic format {"cache_control": {"type": "ephemeral"}}
+    """
+    if message["content"] is None or isinstance(message["content"], str):
+        return False
+
+    for content in message["content"]:
+        if (
+            content["type"] == "text"
+            and content.get("cache_control") is not None
+            and content["cache_control"]["type"] == "ephemeral"  # type: ignore
+        ):
+            return True
+
+    return False
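A quick sketch of the new helper in use (messages are illustrative):

```python
from litellm.utils import is_cached_message

cached_msg = {
    "role": "user",
    "content": [
        {
            "type": "text",
            "text": "What are the key terms and conditions in this agreement?",
            "cache_control": {"type": "ephemeral"},
        }
    ],
}
plain_msg = {"role": "user", "content": "What are the key terms?"}

assert is_cached_message(message=cached_msg) is True
assert is_cached_message(message=plain_msg) is False
```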