Add bedrock latency optimized inference support (#9623)

* fix(converse_transformation.py): add performanceConfig param support on bedrock

Closes https://github.com/BerriAI/litellm/issues/7606

* fix(converse_transformation.py): refactor to use more flexible single getter for params which are separate config blocks

* test(test_main.py): add e2e mock test for bedrock performance config

* build(model_prices_and_context_window.json): add versioned multimodal embedding

* refactor(multimodal_embeddings/): migrate to config pattern

* feat(vertex_ai/multimodalembeddings): calculate usage for multimodal embedding calls

Enables cost calculation for multimodal embeddings

* feat(vertex_ai/multimodalembeddings): get usage object for embedding calls

Ensures accurate cost tracking for Vertex AI multimodal embedding calls

* fix(embedding_handler.py): remove unused imports

* fix: fix linting errors

* fix: handle response api usage calculation

* test(test_vertex_ai_multimodal_embedding_transformation.py): update tests

* test: mark flaky test

* feat(vertex_ai/multimodal_embeddings/transformation.py): support text+image+video input

* docs(vertex.md): document sending text + image to vertex multimodal embeddings

* test: remove incorrect file

* fix(multimodal_embeddings/transformation.py): fix linting error

* style: remove unused import
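
A minimal sketch of what the first bullet enables, mirroring the e2e mock test added further down (model id and prompt are just examples; `performanceConfig` is forwarded into the Converse request body as typed by `PerformanceConfigBlock`):

```python
import litellm

# Ask Bedrock for the latency-optimized inference tier on a Converse model.
# Accepted values per this commit's PerformanceConfigBlock: "optimized" | "throughput".
response = litellm.completion(
    model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    performanceConfig={"latency": "optimized"},
)
print(response.choices[0].message.content)
```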
Krish Dholakia 2025-03-29 00:23:09 -07:00 committed by GitHub
parent 0742e6afd6
commit 5ac61a7572
19 changed files with 806 additions and 245 deletions


@@ -2080,7 +2080,12 @@ print(response)
## **Multi-Modal Embeddings**

Known Limitations:
- Only supports 1 image / video per request
- Only supports GCS or base64 encoded images / videos

### Usage
<Tabs>
<TabItem value="sdk" label="SDK">
@@ -2296,6 +2301,115 @@ print(f"Text Embedding: {embeddings.text_embedding}")
</Tabs>
### Text + Image + Video Embeddings
<Tabs>
<TabItem value="sdk" label="SDK">
Text + Image
```python
response = await litellm.aembedding(
    model="vertex_ai/multimodalembedding@001",
    input=["hey", "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"] # will be sent as a gcs image
)
```
Text + Video
```python
response = await litellm.aembedding(
    model="vertex_ai/multimodalembedding@001",
    input=["hey", "gs://my-bucket/embeddings/supermarket-video.mp4"] # will be sent as a gcs video
)
```
Image + Video
```python
response = await litellm.aembedding(
    model="vertex_ai/multimodalembedding@001",
    input=["gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png", "gs://my-bucket/embeddings/supermarket-video.mp4"] # sent as a gcs image + gcs video
)
```
</TabItem>
<TabItem value="proxy" label="LiteLLM PROXY (Unified Endpoint)">
1. Add model to config.yaml

```yaml
model_list:
  - model_name: multimodalembedding@001
    litellm_params:
      model: vertex_ai/multimodalembedding@001
      vertex_project: "adroit-crow-413218"
      vertex_location: "us-central1"
      vertex_credentials: adroit-crow-413218-a956eef1a2a8.json

litellm_settings:
  drop_params: True
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make request using the OpenAI Python SDK or Langchain Python SDK

Text + Image
```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# request sent to model set on litellm proxy, `litellm --model`
response = client.embeddings.create(
    model="multimodalembedding@001",
    input=["hey", "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"],
)

print(response)
```
Text + Video
```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# request sent to model set on litellm proxy, `litellm --model`
response = client.embeddings.create(
    model="multimodalembedding@001",
    input=["hey", "gs://my-bucket/embeddings/supermarket-video.mp4"],
)

print(response)
```
Image + Video
```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# request sent to model set on litellm proxy, `litellm --model`
response = client.embeddings.create(
    model="multimodalembedding@001",
    input=["gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png", "gs://my-bucket/embeddings/supermarket-video.mp4"],
)

print(response)
```
</TabItem>
</Tabs>
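
Usage and cost tracking are now populated for these calls (see the `calculate_usage` and pricing changes below). A rough sketch of reading the new usage fields; fields that don't apply to a given request may be absent, so read them defensively:

```python
import litellm

response = litellm.embedding(
    model="vertex_ai/multimodalembedding@001",
    input=["hey", "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"],
)

# usage is derived from the request instances and the Vertex predictions
details = response.usage.prompt_tokens_details
print(getattr(details, "character_count", None))      # characters of text sent
print(getattr(details, "image_count", None))           # number of images sent
print(getattr(details, "video_length_seconds", None))  # summed video segment durations
print(response._hidden_params.get("response_cost"))    # cost from per-character / per-image / per-second prices
```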
## **Image Generation Models**

Usage


@@ -2,7 +2,7 @@
## File for 'response_cost' calculation in Logging
import time
from functools import lru_cache
from typing import Any, List, Literal, Optional, Tuple, Union, cast

from pydantic import BaseModel
@@ -462,13 +462,36 @@ def _model_contains_known_llm_provider(model: str) -> bool:
def _get_usage_object(
    completion_response: Any,
) -> Optional[Usage]:
    usage_obj = cast(
        Union[Usage, ResponseAPIUsage, dict, BaseModel],
        (
            completion_response.get("usage")
            if isinstance(completion_response, dict)
            else getattr(completion_response, "get", lambda x: None)("usage")
        ),
    )

    if usage_obj is None:
        return None
    if isinstance(usage_obj, Usage):
        return usage_obj
    elif (
        usage_obj is not None
        and (isinstance(usage_obj, dict) or isinstance(usage_obj, ResponseAPIUsage))
        and ResponseAPILoggingUtils._is_response_api_usage(usage_obj)
    ):
        return ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
            usage_obj
        )
    elif isinstance(usage_obj, dict):
        return Usage(**usage_obj)
    elif isinstance(usage_obj, BaseModel):
        return Usage(**usage_obj.model_dump())
    else:
        verbose_logger.debug(
            f"Unknown usage object type: {type(usage_obj)}, usage_obj: {usage_obj}"
        )
        return None


def _is_known_usage_objects(usage_obj):
@@ -662,6 +685,7 @@ def completion_cost(  # noqa: PLR0915
        elif len(prompt) > 0:
            prompt_tokens = token_counter(model=model, text=prompt)
            completion_tokens = token_counter(model=model, text=completion)

    if model is None:
        raise ValueError(
            f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"


@@ -121,6 +121,31 @@ def _get_completion_token_base_cost(model_info: ModelInfo, usage: Usage) -> floa
    return model_info["output_cost_per_token"]


def calculate_cost_component(
    model_info: ModelInfo, cost_key: str, usage_value: Optional[float]
) -> float:
    """
    Generic cost calculator for any usage component

    Args:
        model_info: Dictionary containing cost information
        cost_key: The key for the cost multiplier in model_info (e.g., 'input_cost_per_audio_token')
        usage_value: The actual usage value (e.g., number of tokens, characters, seconds)

    Returns:
        float: The calculated cost
    """
    cost_per_unit = model_info.get(cost_key)
    if (
        cost_per_unit is not None
        and isinstance(cost_per_unit, float)
        and usage_value is not None
        and usage_value > 0
    ):
        return float(usage_value) * cost_per_unit
    return 0.0


def generic_cost_per_token(
    model: str, usage: Usage, custom_llm_provider: str
) -> Tuple[float, float]:
@@ -136,6 +161,7 @@ def generic_cost_per_token(
    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """

    ## GET MODEL INFO
    model_info = get_model_info(model=model, custom_llm_provider=custom_llm_provider)
@@ -146,6 +172,9 @@ def generic_cost_per_token(
    text_tokens = usage.prompt_tokens
    cache_hit_tokens = 0
    audio_tokens = 0
    character_count = 0
    image_count = 0
    video_length_seconds = 0
    if usage.prompt_tokens_details:
        cache_hit_tokens = (
            cast(
@@ -163,6 +192,24 @@ def generic_cost_per_token(
            cast(Optional[int], getattr(usage.prompt_tokens_details, "audio_tokens", 0))
            or 0
        )
        character_count = (
            cast(
                Optional[int],
                getattr(usage.prompt_tokens_details, "character_count", 0),
            )
            or 0
        )
        image_count = (
            cast(Optional[int], getattr(usage.prompt_tokens_details, "image_count", 0))
            or 0
        )
        video_length_seconds = (
            cast(
                Optional[int],
                getattr(usage.prompt_tokens_details, "video_length_seconds", 0),
            )
            or 0
        )

    ## EDGE CASE - text tokens not set inside PromptTokensDetails
    if text_tokens == 0:
@@ -172,28 +219,38 @@ def generic_cost_per_token(
    prompt_cost = float(text_tokens) * prompt_base_cost

    ### CACHE READ COST
    prompt_cost += calculate_cost_component(
        model_info, "cache_read_input_token_cost", cache_hit_tokens
    )

    ### AUDIO COST
    prompt_cost += calculate_cost_component(
        model_info, "input_cost_per_audio_token", audio_tokens
    )

    ### CACHE WRITING COST
    prompt_cost += calculate_cost_component(
        model_info,
        "cache_creation_input_token_cost",
        usage._cache_creation_input_tokens,
    )

    ### CHARACTER COST
    prompt_cost += calculate_cost_component(
        model_info, "input_cost_per_character", character_count
    )

    ### IMAGE COUNT COST
    prompt_cost += calculate_cost_component(
        model_info, "input_cost_per_image", image_count
    )

    ### VIDEO LENGTH COST
    prompt_cost += calculate_cost_component(
        model_info, "input_cost_per_video_per_second", video_length_seconds
    )

    ## CALCULATE OUTPUT COST
    completion_base_cost = _get_completion_token_base_cost(

@@ -66,6 +66,13 @@ class AmazonConverseConfig(BaseConfig):
    def custom_llm_provider(self) -> Optional[str]:
        return "bedrock_converse"

    @classmethod
    def get_config_blocks(cls) -> dict:
        return {
            "guardrailConfig": GuardrailConfigBlock,
            "performanceConfig": PerformanceConfigBlock,
        }

    @classmethod
    def get_config(cls):
        return {
@@ -394,7 +401,6 @@ class AmazonConverseConfig(BaseConfig):
        optional_params: dict,
        messages: Optional[List[AllMessageValues]] = None,
    ) -> CommonRequestObject:
        ## VALIDATE REQUEST
        """
        Bedrock doesn't support tool calling without `tools=` param specified.
@@ -420,11 +426,11 @@ class AmazonConverseConfig(BaseConfig):
            AmazonConverseConfig.__annotations__.keys()
        ) + ["top_k"]
        supported_tool_call_params = ["tools", "tool_choice"]
        supported_config_params = list(self.get_config_blocks().keys())
        total_supported_params = (
            supported_converse_params
            + supported_tool_call_params
            + supported_config_params
        )
        inference_params.pop("json_mode", None)  # used for handling json_schema
@@ -463,12 +469,11 @@ class AmazonConverseConfig(BaseConfig):
            ),
        }

        # Handle all config blocks
        for config_name, config_class in self.get_config_blocks().items():
            config_value = inference_params.pop(config_name, None)
            if config_value is not None:
                data[config_name] = config_class(**config_value)  # type: ignore

        # Tool Config
        if bedrock_tool_config is not None:


@@ -239,6 +239,7 @@ def cost_per_token(
    Raises:
        Exception if model requires >128k pricing, but model cost not mapped
    """

    ## GET MODEL INFO
    model_info = litellm.get_model_info(
        model=model, custom_llm_provider=custom_llm_provider


@@ -1,5 +1,5 @@
import json
from typing import Literal, Optional, Union

import httpx
@@ -14,15 +14,11 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
    VertexAIError,
    VertexLLM,
)
from litellm.types.utils import EmbeddingResponse

from .transformation import VertexAIMultimodalEmbeddingConfig

vertex_multimodal_embedding_handler = VertexAIMultimodalEmbeddingConfig()


class VertexMultimodalEmbedding(VertexLLM):
@@ -41,9 +37,11 @@ class VertexMultimodalEmbedding(VertexLLM):
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        litellm_params: dict,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        headers: dict = {},
        encoding=None,
        vertex_project=None,
        vertex_location=None,
@@ -86,31 +84,18 @@ class VertexMultimodalEmbedding(VertexLLM):
        else:
            sync_handler = client  # type: ignore

        request_data = vertex_multimodal_embedding_handler.transform_embedding_request(
            model, input, optional_params, headers
        )

        headers = vertex_multimodal_embedding_handler.validate_environment(
            headers=headers,
            model=model,
            messages=[],
            optional_params=optional_params,
            api_key=auth_header,
            api_base=api_base,
        )

        ## LOGGING
        logging_obj.pre_call(
@@ -132,6 +117,10 @@ class VertexMultimodalEmbedding(VertexLLM):
            headers=headers,
            client=client,
            model_response=model_response,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            api_key=api_key,
        )

        response = sync_handler.post(
@@ -140,34 +129,30 @@ class VertexMultimodalEmbedding(VertexLLM):
            data=json.dumps(request_data),
        )

        return vertex_multimodal_embedding_handler.transform_embedding_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=request_data,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )

    async def async_multimodal_embedding(
        self,
        model: str,
        api_base: str,
        optional_params: dict,
        litellm_params: dict,
        data: dict,
        model_response: litellm.EmbeddingResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        logging_obj: LiteLLMLoggingObj,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
        api_key: Optional[str] = None,
    ) -> litellm.EmbeddingResponse:
        if client is None:
            _params = {}
@@ -191,112 +176,13 @@ class VertexMultimodalEmbedding(VertexLLM):
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        return vertex_multimodal_embedding_handler.transform_embedding_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=data,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )
def _process_input_element(self, input_element: str) -> Instance:
"""
Process the input element for multimodal embedding requests. checks if the if the input is gcs uri, base64 encoded image or plain text.
Args:
input_element (str): The input element to process.
Returns:
Dict[str, Any]: A dictionary representing the processed input element.
"""
if len(input_element) == 0:
return Instance(text=input_element)
elif "gs://" in input_element:
if "mp4" in input_element:
return Instance(video=InstanceVideo(gcsUri=input_element))
else:
return Instance(image=InstanceImage(gcsUri=input_element))
elif is_base64_encoded(s=input_element):
return Instance(
image=InstanceImage(
bytesBase64Encoded=(
input_element.split(",")[1]
if "," in input_element
else input_element
)
)
)
else:
return Instance(text=input_element)
def process_openai_embedding_input(
self, _input: Union[list, str]
) -> List[Instance]:
"""
Process the input for multimodal embedding requests.
Args:
_input (Union[list, str]): The input data to process.
Returns:
List[Instance]: A list of processed VertexAI Instance objects.
"""
_input_list = None
if not isinstance(_input, list):
_input_list = [_input]
else:
_input_list = _input
processed_instances = []
for element in _input_list:
if isinstance(element, str):
instance = Instance(**self._process_input_element(element))
elif isinstance(element, dict):
instance = Instance(**element)
else:
raise ValueError(f"Unsupported input type: {type(element)}")
processed_instances.append(instance)
return processed_instances
def transform_embedding_response_to_openai(
self, predictions: MultimodalPredictions
) -> List[Embedding]:
openai_embeddings: List[Embedding] = []
if "predictions" in predictions:
for idx, _prediction in enumerate(predictions["predictions"]):
if _prediction:
if "textEmbedding" in _prediction:
openai_embedding_object = Embedding(
embedding=_prediction["textEmbedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
elif "imageEmbedding" in _prediction:
openai_embedding_object = Embedding(
embedding=_prediction["imageEmbedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
elif "videoEmbeddings" in _prediction:
for video_embedding in _prediction["videoEmbeddings"]:
openai_embedding_object = Embedding(
embedding=video_embedding["embedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
return openai_embeddings


@@ -0,0 +1,297 @@
from typing import List, Optional, Union, cast
from httpx import Headers, Response
from litellm.exceptions import InternalServerError
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import LiteLLMLoggingObj
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.llms.vertex_ai import (
Instance,
InstanceImage,
InstanceVideo,
MultimodalPredictions,
VertexMultimodalEmbeddingRequest,
)
from litellm.types.utils import (
Embedding,
EmbeddingResponse,
PromptTokensDetailsWrapper,
Usage,
)
from litellm.utils import _count_characters, is_base64_encoded
from ...base_llm.embedding.transformation import BaseEmbeddingConfig
from ..common_utils import VertexAIError
class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
def get_supported_openai_params(self, model: str) -> list:
return ["dimensions"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "dimensions":
optional_params["outputDimensionality"] = value
return optional_params
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
default_headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {api_key}",
}
headers.update(default_headers)
return headers
def _process_input_element(self, input_element: str) -> Instance:
"""
Process the input element for multimodal embedding requests. Checks whether the input is a GCS URI, a base64 encoded image, or plain text.
Args:
input_element (str): The input element to process.
Returns:
Instance: The processed input element.
"""
if len(input_element) == 0:
return Instance(text=input_element)
elif "gs://" in input_element:
if "mp4" in input_element:
return Instance(video=InstanceVideo(gcsUri=input_element))
else:
return Instance(image=InstanceImage(gcsUri=input_element))
elif is_base64_encoded(s=input_element):
return Instance(
image=InstanceImage(
bytesBase64Encoded=(
input_element.split(",")[1]
if "," in input_element
else input_element
)
)
)
else:
return Instance(text=input_element)
def process_openai_embedding_input(
self, _input: Union[list, str]
) -> List[Instance]:
"""
Process the input for multimodal embedding requests.
Args:
_input (Union[list, str]): The input data to process.
Returns:
List[Instance]: A list of processed VertexAI Instance objects.
"""
_input_list = [_input] if not isinstance(_input, list) else _input
processed_instances = []
i = 0
while i < len(_input_list):
current = _input_list[i]
# Look ahead for potential media elements
next_elem = _input_list[i + 1] if i + 1 < len(_input_list) else None
# If current is a text and next is a GCS URI, or current is a GCS URI
if isinstance(current, str):
instance_args: Instance = {}
# Process current element
if "gs://" not in current:
instance_args["text"] = current
elif "mp4" in current:
instance_args["video"] = InstanceVideo(gcsUri=current)
else:
instance_args["image"] = InstanceImage(gcsUri=current)
# Check next element if it's a GCS URI
if next_elem and isinstance(next_elem, str) and "gs://" in next_elem:
if "mp4" in next_elem:
instance_args["video"] = InstanceVideo(gcsUri=next_elem)
else:
instance_args["image"] = InstanceImage(gcsUri=next_elem)
i += 2 # Skip next element since we processed it
else:
i += 1 # Move to next element
processed_instances.append(instance_args)
continue
# Handle dict or other types
if isinstance(current, dict):
instance = Instance(**current)
processed_instances.append(instance)
else:
raise ValueError(f"Unsupported input type: {type(current)}")
i += 1
return processed_instances
def transform_embedding_request(
self,
model: str,
input: AllEmbeddingInputValues,
optional_params: dict,
headers: dict,
) -> dict:
optional_params = optional_params or {}
request_data = VertexMultimodalEmbeddingRequest(instances=[])
if "instances" in optional_params:
request_data["instances"] = optional_params["instances"]
elif isinstance(input, list):
vertex_instances: List[Instance] = self.process_openai_embedding_input(
_input=input
)
request_data["instances"] = vertex_instances
else:
# construct instances
vertex_request_instance = Instance(**optional_params)
if isinstance(input, str):
vertex_request_instance = self._process_input_element(input)
request_data["instances"] = [vertex_request_instance]
return cast(dict, request_data)
def transform_embedding_response(
self,
model: str,
raw_response: Response,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str],
request_data: dict,
optional_params: dict,
litellm_params: dict,
) -> EmbeddingResponse:
if raw_response.status_code != 200:
raise Exception(f"Error: {raw_response.status_code} {raw_response.text}")
_json_response = raw_response.json()
if "predictions" not in _json_response:
raise InternalServerError(
message=f"embedding response does not contain 'predictions', got {_json_response}",
llm_provider="vertex_ai",
model=model,
)
_predictions = _json_response["predictions"]
vertex_predictions = MultimodalPredictions(predictions=_predictions)
model_response.data = self.transform_embedding_response_to_openai(
predictions=vertex_predictions
)
model_response.model = model
model_response.usage = self.calculate_usage(
request_data=cast(VertexMultimodalEmbeddingRequest, request_data),
vertex_predictions=vertex_predictions,
)
return model_response
def calculate_usage(
self,
request_data: VertexMultimodalEmbeddingRequest,
vertex_predictions: MultimodalPredictions,
) -> Usage:
## Calculate text embeddings usage
prompt: Optional[str] = None
character_count: Optional[int] = None
for instance in request_data["instances"]:
text = instance.get("text")
if text:
if prompt is None:
prompt = text
else:
prompt += text
if prompt is not None:
character_count = _count_characters(prompt)
## Calculate image embeddings usage
image_count = 0
for instance in request_data["instances"]:
if instance.get("image"):
image_count += 1
## Calculate video embeddings usage
video_length_seconds = 0
for prediction in vertex_predictions["predictions"]:
video_embeddings = prediction.get("videoEmbeddings")
if video_embeddings:
for embedding in video_embeddings:
duration = embedding["endOffsetSec"] - embedding["startOffsetSec"]
video_length_seconds += duration
prompt_tokens_details = PromptTokensDetailsWrapper(
character_count=character_count,
image_count=image_count,
video_length_seconds=video_length_seconds,
)
return Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
prompt_tokens_details=prompt_tokens_details,
)
def transform_embedding_response_to_openai(
self, predictions: MultimodalPredictions
) -> List[Embedding]:
openai_embeddings: List[Embedding] = []
if "predictions" in predictions:
for idx, _prediction in enumerate(predictions["predictions"]):
if _prediction:
if "textEmbedding" in _prediction:
openai_embedding_object = Embedding(
embedding=_prediction["textEmbedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
elif "imageEmbedding" in _prediction:
openai_embedding_object = Embedding(
embedding=_prediction["imageEmbedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
elif "videoEmbeddings" in _prediction:
for video_embedding in _prediction["videoEmbeddings"]:
openai_embedding_object = Embedding(
embedding=video_embedding["embedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
return openai_embeddings
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return VertexAIError(
status_code=status_code, message=error_message, headers=headers
)


@@ -3723,6 +3723,7 @@ def embedding(  # noqa: PLR0915
            encoding=encoding,
            logging_obj=logging,
            optional_params=optional_params,
            litellm_params=litellm_params_dict,
            model_response=EmbeddingResponse(),
            vertex_project=vertex_ai_project,
            vertex_location=vertex_ai_location,


@@ -5409,6 +5409,23 @@
"supported_modalities": ["text", "image", "video"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
},
"multimodalembedding@001": {
"max_tokens": 2048,
"max_input_tokens": 2048,
"output_vector_size": 768,
"input_cost_per_character": 0.0000002,
"input_cost_per_image": 0.0001,
"input_cost_per_video_per_second": 0.0005,
"input_cost_per_video_per_second_above_8s_interval": 0.0010,
"input_cost_per_video_per_second_above_15s_interval": 0.0020,
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0,
"litellm_provider": "vertex_ai-embedding-models",
"mode": "embedding",
"supported_endpoints": ["/v1/embeddings"],
"supported_modalities": ["text", "image", "video"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
},
"text-embedding-large-exp-03-07": { "text-embedding-large-exp-03-07": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 8192, "max_input_tokens": 8192,


@@ -1,4 +1,4 @@
from typing import Any, Dict, Union, cast, get_type_hints

import litellm
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
@@ -78,16 +78,22 @@ class ResponsesAPIRequestUtils:
class ResponseAPILoggingUtils:
    @staticmethod
    def _is_response_api_usage(usage: Union[dict, ResponseAPIUsage]) -> bool:
        """returns True if usage is from OpenAI Response API"""
        if isinstance(usage, ResponseAPIUsage):
            return True
        if "input_tokens" in usage and "output_tokens" in usage:
            return True
        return False

    @staticmethod
    def _transform_response_api_usage_to_chat_usage(
        usage: Union[dict, ResponseAPIUsage]
    ) -> Usage:
        """Transforms the ResponseAPIUsage object to a Usage object"""
        response_api_usage: ResponseAPIUsage = (
            ResponseAPIUsage(**usage) if isinstance(usage, dict) else usage
        )
        prompt_tokens: int = response_api_usage.input_tokens or 0
        completion_tokens: int = response_api_usage.output_tokens or 0
        return Usage(


@@ -191,6 +191,10 @@ class ContentBlockDeltaEvent(TypedDict, total=False):
    reasoningContent: BedrockConverseReasoningContentBlockDelta


class PerformanceConfigBlock(TypedDict):
    latency: Literal["optimized", "throughput"]


class CommonRequestObject(
    TypedDict, total=False
):  # common request object across sync + async flows
@@ -200,6 +204,7 @@ class CommonRequestObject(
    system: List[SystemContentBlock]
    toolConfig: ToolConfigBlock
    guardrailConfig: Optional[GuardrailConfigBlock]
    performanceConfig: Optional[PerformanceConfigBlock]


class RequestObject(CommonRequestObject, total=False):
@@ -429,7 +434,9 @@ class AmazonNovaCanvasColorGuidedGenerationParams(TypedDict, total=False):
    negativeText: str


class AmazonNovaCanvasColorGuidedRequest(
    AmazonNovaCanvasRequestBase, TypedDict, total=False
):
    """
    Request for Amazon Nova Canvas Color Guided Generation API


@@ -369,9 +369,15 @@ class ResponseTuningJob(TypedDict):
    updateTime: Optional[str]


class VideoSegmentConfig(TypedDict, total=False):
    startOffsetSec: int
    endOffsetSec: int
    intervalSec: int


class InstanceVideo(TypedDict, total=False):
    gcsUri: str
    videoSegmentConfig: VideoSegmentConfig


class InstanceImage(TypedDict, total=False):
@@ -386,7 +392,7 @@ class Instance(TypedDict, total=False):
    video: InstanceVideo


class VertexMultimodalEmbeddingRequest(TypedDict):
    instances: List[Instance]
@@ -402,7 +408,7 @@ class MultimodalPrediction(TypedDict, total=False):
    videoEmbeddings: List[VideoEmbedding]


class MultimodalPredictions(TypedDict):
    predictions: List[MultimodalPrediction]


@@ -779,6 +779,24 @@ class PromptTokensDetailsWrapper(
    image_tokens: Optional[int] = None
    """Image tokens sent to the model."""

    character_count: Optional[int] = None
    """Character count sent to the model. Used for Vertex AI multimodal embeddings."""

    image_count: Optional[int] = None
    """Number of images sent to the model. Used for Vertex AI multimodal embeddings."""

    video_length_seconds: Optional[float] = None
    """Length of videos sent to the model. Used for Vertex AI multimodal embeddings."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.character_count is None:
            del self.character_count
        if self.image_count is None:
            del self.image_count
        if self.video_length_seconds is None:
            del self.video_length_seconds


class Usage(CompletionUsage):
    _cache_creation_input_tokens: int = PrivateAttr(


@@ -5409,6 +5409,23 @@
"supported_modalities": ["text", "image", "video"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
},
"multimodalembedding@001": {
"max_tokens": 2048,
"max_input_tokens": 2048,
"output_vector_size": 768,
"input_cost_per_character": 0.0000002,
"input_cost_per_image": 0.0001,
"input_cost_per_video_per_second": 0.0005,
"input_cost_per_video_per_second_above_8s_interval": 0.0010,
"input_cost_per_video_per_second_above_15s_interval": 0.0020,
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0,
"litellm_provider": "vertex_ai-embedding-models",
"mode": "embedding",
"supported_endpoints": ["/v1/embeddings"],
"supported_modalities": ["text", "image", "video"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
},
"text-embedding-large-exp-03-07": { "text-embedding-large-exp-03-07": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 8192, "max_input_tokens": 8192,


@@ -1,43 +0,0 @@
import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
sys.path.insert(
0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path
from litellm.llms.vertex_ai.multimodal_embeddings.embedding_handler import (
VertexMultimodalEmbedding,
)
from litellm.types.llms.vertex_ai import Instance, InstanceImage
class TestVertexMultimodalEmbedding:
def setup_method(self):
self.embedding_handler = VertexMultimodalEmbedding()
def test_process_openai_embedding_input(self):
input_data = [
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=",
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=",
]
expected_output = [
Instance(
image=InstanceImage(bytesBase64Encoded=input_data[0].split(",")[1])
),
Instance(
image=InstanceImage(bytesBase64Encoded=input_data[1].split(",")[1])
),
]
assert (
self.embedding_handler._process_input_element(input_data[0])
== expected_output[0]
)
assert (
self.embedding_handler._process_input_element(input_data[1])
== expected_output[1]
)


@@ -0,0 +1,78 @@
import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
sys.path.insert(
0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path
from litellm.llms.vertex_ai.multimodal_embeddings.transformation import (
VertexAIMultimodalEmbeddingConfig,
)
from litellm.types.llms.vertex_ai import Instance, InstanceImage, InstanceVideo
class TestVertexMultimodalEmbedding:
def setup_method(self):
self.config = VertexAIMultimodalEmbeddingConfig()
def test_process_openai_embedding_input(self):
input_data = [
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=",
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=",
]
expected_output = [
Instance(
image=InstanceImage(bytesBase64Encoded=input_data[0].split(",")[1])
),
Instance(
image=InstanceImage(bytesBase64Encoded=input_data[1].split(",")[1])
),
]
assert self.config._process_input_element(input_data[0]) == expected_output[0]
assert self.config._process_input_element(input_data[1]) == expected_output[1]
def test_process_str_and_video_input(self):
input_data = ["hi", "gs://my-bucket/embeddings/supermarket-video.mp4"]
expected_output = [
Instance(
text="hi",
video=InstanceVideo(
gcsUri="gs://my-bucket/embeddings/supermarket-video.mp4"
),
),
]
assert self.config.process_openai_embedding_input(input_data) == expected_output
def test_process_list_of_str_and_str_input(self):
input_data = ["hi", "hello"]
expected_output = [
Instance(text="hi"),
Instance(text="hello"),
]
assert self.config.process_openai_embedding_input(input_data) == expected_output
def test_process_list_of_str_and_video_input(self):
input_data = [
"hi",
"hello",
"gs://my-bucket/embeddings/supermarket-video.mp4",
"hey",
]
expected_output = [
Instance(text="hi"),
Instance(
text="hello",
video=InstanceVideo(
gcsUri="gs://my-bucket/embeddings/supermarket-video.mp4"
),
),
Instance(text="hey"),
]
assert (
self.config.process_openai_embedding_input(input_data) == expected_output
), f"Expected {expected_output}, but got {self.config.process_openai_embedding_input(input_data)}"


@@ -239,3 +239,23 @@ async def test_url_with_format_param_openai(model, sync_mode):
    json_str = json.dumps(mock_client.call_args.kwargs)
    assert "format" not in json_str
def test_bedrock_latency_optimized_inference():
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
with patch.object(client, "post") as mock_post:
try:
response = litellm.completion(
model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
messages=[{"role": "user", "content": "Hello, how are you?"}],
performanceConfig={"latency": "optimized"},
client=client,
)
except Exception as e:
print(e)
mock_post.assert_called_once()
json_data = json.loads(mock_post.call_args.kwargs["data"])
assert json_data["performanceConfig"]["latency"] == "optimized"


@@ -416,6 +416,7 @@ def validate_web_search_annotations(annotations: ChatCompletionAnnotation):
        validate_response_url_citation(url_citation)


@pytest.mark.flaky(reruns=3)
def test_openai_web_search():
    """Makes a simple web search request and validates the response contains web search annotations and all expected fields are present"""
    litellm._turn_on_debug()


@@ -2197,6 +2197,55 @@ def test_vertexai_embedding_embedding_latest():
        pytest.fail(f"Error occurred: {e}")
def test_vertexai_multimodalembedding_embedding_latest():
try:
import requests, base64
load_vertex_ai_credentials()
litellm._turn_on_debug()
response = embedding(
model="vertex_ai/multimodalembedding@001",
input=["hi"],
dimensions=1,
auto_truncate=True,
task_type="RETRIEVAL_QUERY",
)
print(f"response.usage: {response.usage}")
assert response.usage is not None
assert response.usage.prompt_tokens_details is not None
assert response._hidden_params["response_cost"] > 0
print(f"response:", response)
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_vertexai_embedding_embedding_latest():
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
response = embedding(
model="vertex_ai/text-embedding-004",
input=["hi"],
dimensions=1,
auto_truncate=True,
task_type="RETRIEVAL_QUERY",
)
assert len(response.data[0]["embedding"]) == 1
assert response.usage.prompt_tokens > 0
print(f"response:", response)
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="need to get gecko permissions on vertex ai to run this test")
@pytest.mark.flaky(retries=3, delay=1)
def test_vertexai_embedding_embedding_latest_input_type():