mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
Add bedrock latency optimized inference support (#9623)
* fix(converse_transformation.py): add performanceConfig param support on bedrock Closes https://github.com/BerriAI/litellm/issues/7606 * fix(converse_transformation.py): refactor to use more flexible single getter for params which are separate config blocks * test(test_main.py): add e2e mock test for bedrock performance config * build(model_prices_and_context_window.json): add versioned multimodal embedding * refactor(multimodal_embeddings/): migrate to config pattern * feat(vertex_ai/multimodalembeddings): calculate usage for multimodal embedding calls Enables cost calculation for multimodal embeddings * feat(vertex_ai/multimodalembeddings): get usage object for embedding calls ensures accurate cost tracking for vertexai multimodal embedding calls * fix(embedding_handler.py): remove unused imports * fix: fix linting errors * fix: handle response api usage calculation * test(test_vertex_ai_multimodal_embedding_transformation.py): update tests * test: mark flaky test * feat(vertex_ai/multimodal_embeddings/transformation.py): support text+image+video input * docs(vertex.md): document sending text + image to vertex multimodal embeddings * test: remove incorrect file * fix(multimodal_embeddings/transformation.py): fix linting error * style: remove unused import
This commit is contained in:
parent
0742e6afd6
commit
5ac61a7572
19 changed files with 806 additions and 245 deletions
|
@ -2080,7 +2080,12 @@ print(response)
|
||||||
|
|
||||||
## **Multi-Modal Embeddings**
|
## **Multi-Modal Embeddings**
|
||||||
|
|
||||||
Usage
|
|
||||||
|
Known Limitations:
|
||||||
|
- Only supports 1 image / video / image per request
|
||||||
|
- Only supports GCS or base64 encoded images / videos
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
<TabItem value="sdk" label="SDK">
|
<TabItem value="sdk" label="SDK">
|
||||||
|
@ -2296,6 +2301,115 @@ print(f"Text Embedding: {embeddings.text_embedding}")
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
### Text + Image + Video Embeddings
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
Text + Image
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
model="vertex_ai/multimodalembedding@001",
|
||||||
|
input=["hey", "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"] # will be sent as a gcs image
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Text + Video
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
model="vertex_ai/multimodalembedding@001",
|
||||||
|
input=["hey", "gs://my-bucket/embeddings/supermarket-video.mp4"] # will be sent as a gcs image
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Image + Video
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
model="vertex_ai/multimodalembedding@001",
|
||||||
|
input=["gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png", "gs://my-bucket/embeddings/supermarket-video.mp4"] # will be sent as a gcs image
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="LiteLLM PROXY (Unified Endpoint)">
|
||||||
|
|
||||||
|
1. Add model to config.yaml
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: multimodalembedding@001
|
||||||
|
litellm_params:
|
||||||
|
model: vertex_ai/multimodalembedding@001
|
||||||
|
vertex_project: "adroit-crow-413218"
|
||||||
|
vertex_location: "us-central1"
|
||||||
|
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
drop_params: True
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start Proxy
|
||||||
|
|
||||||
|
```
|
||||||
|
$ litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Make Request use OpenAI Python SDK, Langchain Python SDK
|
||||||
|
|
||||||
|
|
||||||
|
Text + Image
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
|
||||||
|
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||||
|
|
||||||
|
# # request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model="multimodalembedding@001",
|
||||||
|
input = ["hey", "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
Text + Video
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
|
||||||
|
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||||
|
|
||||||
|
# # request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model="multimodalembedding@001",
|
||||||
|
input = ["hey", "gs://my-bucket/embeddings/supermarket-video.mp4"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
Image + Video
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
|
||||||
|
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||||
|
|
||||||
|
# # request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model="multimodalembedding@001",
|
||||||
|
input = ["gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png", "gs://my-bucket/embeddings/supermarket-video.mp4"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
## **Image Generation Models**
|
## **Image Generation Models**
|
||||||
|
|
||||||
Usage
|
Usage
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
## File for 'response_cost' calculation in Logging
|
## File for 'response_cost' calculation in Logging
|
||||||
import time
|
import time
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, List, Literal, Optional, Tuple, Union
|
from typing import Any, List, Literal, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -462,13 +462,36 @@ def _model_contains_known_llm_provider(model: str) -> bool:
|
||||||
def _get_usage_object(
|
def _get_usage_object(
|
||||||
completion_response: Any,
|
completion_response: Any,
|
||||||
) -> Optional[Usage]:
|
) -> Optional[Usage]:
|
||||||
usage_obj: Optional[Usage] = None
|
usage_obj = cast(
|
||||||
if completion_response is not None and isinstance(
|
Union[Usage, ResponseAPIUsage, dict, BaseModel],
|
||||||
completion_response, ModelResponse
|
(
|
||||||
):
|
completion_response.get("usage")
|
||||||
usage_obj = completion_response.get("usage")
|
if isinstance(completion_response, dict)
|
||||||
|
else getattr(completion_response, "get", lambda x: None)("usage")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if usage_obj is None:
|
||||||
|
return None
|
||||||
|
if isinstance(usage_obj, Usage):
|
||||||
return usage_obj
|
return usage_obj
|
||||||
|
elif (
|
||||||
|
usage_obj is not None
|
||||||
|
and (isinstance(usage_obj, dict) or isinstance(usage_obj, ResponseAPIUsage))
|
||||||
|
and ResponseAPILoggingUtils._is_response_api_usage(usage_obj)
|
||||||
|
):
|
||||||
|
return ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
|
||||||
|
usage_obj
|
||||||
|
)
|
||||||
|
elif isinstance(usage_obj, dict):
|
||||||
|
return Usage(**usage_obj)
|
||||||
|
elif isinstance(usage_obj, BaseModel):
|
||||||
|
return Usage(**usage_obj.model_dump())
|
||||||
|
else:
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"Unknown usage object type: {type(usage_obj)}, usage_obj: {usage_obj}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _is_known_usage_objects(usage_obj):
|
def _is_known_usage_objects(usage_obj):
|
||||||
|
@ -662,6 +685,7 @@ def completion_cost( # noqa: PLR0915
|
||||||
elif len(prompt) > 0:
|
elif len(prompt) > 0:
|
||||||
prompt_tokens = token_counter(model=model, text=prompt)
|
prompt_tokens = token_counter(model=model, text=prompt)
|
||||||
completion_tokens = token_counter(model=model, text=completion)
|
completion_tokens = token_counter(model=model, text=completion)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
|
f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
|
||||||
|
|
|
@ -121,6 +121,31 @@ def _get_completion_token_base_cost(model_info: ModelInfo, usage: Usage) -> floa
|
||||||
return model_info["output_cost_per_token"]
|
return model_info["output_cost_per_token"]
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_cost_component(
|
||||||
|
model_info: ModelInfo, cost_key: str, usage_value: Optional[float]
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Generic cost calculator for any usage component
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_info: Dictionary containing cost information
|
||||||
|
cost_key: The key for the cost multiplier in model_info (e.g., 'input_cost_per_audio_token')
|
||||||
|
usage_value: The actual usage value (e.g., number of tokens, characters, seconds)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The calculated cost
|
||||||
|
"""
|
||||||
|
cost_per_unit = model_info.get(cost_key)
|
||||||
|
if (
|
||||||
|
cost_per_unit is not None
|
||||||
|
and isinstance(cost_per_unit, float)
|
||||||
|
and usage_value is not None
|
||||||
|
and usage_value > 0
|
||||||
|
):
|
||||||
|
return float(usage_value) * cost_per_unit
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
def generic_cost_per_token(
|
def generic_cost_per_token(
|
||||||
model: str, usage: Usage, custom_llm_provider: str
|
model: str, usage: Usage, custom_llm_provider: str
|
||||||
) -> Tuple[float, float]:
|
) -> Tuple[float, float]:
|
||||||
|
@ -136,6 +161,7 @@ def generic_cost_per_token(
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
|
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
|
||||||
"""
|
"""
|
||||||
|
|
||||||
## GET MODEL INFO
|
## GET MODEL INFO
|
||||||
model_info = get_model_info(model=model, custom_llm_provider=custom_llm_provider)
|
model_info = get_model_info(model=model, custom_llm_provider=custom_llm_provider)
|
||||||
|
|
||||||
|
@ -146,6 +172,9 @@ def generic_cost_per_token(
|
||||||
text_tokens = usage.prompt_tokens
|
text_tokens = usage.prompt_tokens
|
||||||
cache_hit_tokens = 0
|
cache_hit_tokens = 0
|
||||||
audio_tokens = 0
|
audio_tokens = 0
|
||||||
|
character_count = 0
|
||||||
|
image_count = 0
|
||||||
|
video_length_seconds = 0
|
||||||
if usage.prompt_tokens_details:
|
if usage.prompt_tokens_details:
|
||||||
cache_hit_tokens = (
|
cache_hit_tokens = (
|
||||||
cast(
|
cast(
|
||||||
|
@ -163,6 +192,24 @@ def generic_cost_per_token(
|
||||||
cast(Optional[int], getattr(usage.prompt_tokens_details, "audio_tokens", 0))
|
cast(Optional[int], getattr(usage.prompt_tokens_details, "audio_tokens", 0))
|
||||||
or 0
|
or 0
|
||||||
)
|
)
|
||||||
|
character_count = (
|
||||||
|
cast(
|
||||||
|
Optional[int],
|
||||||
|
getattr(usage.prompt_tokens_details, "character_count", 0),
|
||||||
|
)
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
image_count = (
|
||||||
|
cast(Optional[int], getattr(usage.prompt_tokens_details, "image_count", 0))
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
video_length_seconds = (
|
||||||
|
cast(
|
||||||
|
Optional[int],
|
||||||
|
getattr(usage.prompt_tokens_details, "video_length_seconds", 0),
|
||||||
|
)
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
## EDGE CASE - text tokens not set inside PromptTokensDetails
|
## EDGE CASE - text tokens not set inside PromptTokensDetails
|
||||||
if text_tokens == 0:
|
if text_tokens == 0:
|
||||||
|
@ -172,27 +219,37 @@ def generic_cost_per_token(
|
||||||
|
|
||||||
prompt_cost = float(text_tokens) * prompt_base_cost
|
prompt_cost = float(text_tokens) * prompt_base_cost
|
||||||
|
|
||||||
_cache_read_input_token_cost = model_info.get("cache_read_input_token_cost")
|
|
||||||
|
|
||||||
### CACHE READ COST
|
### CACHE READ COST
|
||||||
if (
|
prompt_cost += calculate_cost_component(
|
||||||
_cache_read_input_token_cost is not None
|
model_info, "cache_read_input_token_cost", cache_hit_tokens
|
||||||
and cache_hit_tokens is not None
|
)
|
||||||
and cache_hit_tokens > 0
|
|
||||||
):
|
|
||||||
prompt_cost += float(cache_hit_tokens) * _cache_read_input_token_cost
|
|
||||||
|
|
||||||
### AUDIO COST
|
### AUDIO COST
|
||||||
|
prompt_cost += calculate_cost_component(
|
||||||
audio_token_cost = model_info.get("input_cost_per_audio_token")
|
model_info, "input_cost_per_audio_token", audio_tokens
|
||||||
if audio_token_cost is not None and audio_tokens is not None and audio_tokens > 0:
|
)
|
||||||
prompt_cost += float(audio_tokens) * audio_token_cost
|
|
||||||
|
|
||||||
### CACHE WRITING COST
|
### CACHE WRITING COST
|
||||||
_cache_creation_input_token_cost = model_info.get("cache_creation_input_token_cost")
|
prompt_cost += calculate_cost_component(
|
||||||
if _cache_creation_input_token_cost is not None:
|
model_info,
|
||||||
prompt_cost += (
|
"cache_creation_input_token_cost",
|
||||||
float(usage._cache_creation_input_tokens) * _cache_creation_input_token_cost
|
usage._cache_creation_input_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
### CHARACTER COST
|
||||||
|
|
||||||
|
prompt_cost += calculate_cost_component(
|
||||||
|
model_info, "input_cost_per_character", character_count
|
||||||
|
)
|
||||||
|
|
||||||
|
### IMAGE COUNT COST
|
||||||
|
prompt_cost += calculate_cost_component(
|
||||||
|
model_info, "input_cost_per_image", image_count
|
||||||
|
)
|
||||||
|
|
||||||
|
### VIDEO LENGTH COST
|
||||||
|
prompt_cost += calculate_cost_component(
|
||||||
|
model_info, "input_cost_per_video_per_second", video_length_seconds
|
||||||
)
|
)
|
||||||
|
|
||||||
## CALCULATE OUTPUT COST
|
## CALCULATE OUTPUT COST
|
||||||
|
|
|
@ -66,6 +66,13 @@ class AmazonConverseConfig(BaseConfig):
|
||||||
def custom_llm_provider(self) -> Optional[str]:
|
def custom_llm_provider(self) -> Optional[str]:
|
||||||
return "bedrock_converse"
|
return "bedrock_converse"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config_blocks(cls) -> dict:
|
||||||
|
return {
|
||||||
|
"guardrailConfig": GuardrailConfigBlock,
|
||||||
|
"performanceConfig": PerformanceConfigBlock,
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls):
|
def get_config(cls):
|
||||||
return {
|
return {
|
||||||
|
@ -394,7 +401,6 @@ class AmazonConverseConfig(BaseConfig):
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
messages: Optional[List[AllMessageValues]] = None,
|
messages: Optional[List[AllMessageValues]] = None,
|
||||||
) -> CommonRequestObject:
|
) -> CommonRequestObject:
|
||||||
|
|
||||||
## VALIDATE REQUEST
|
## VALIDATE REQUEST
|
||||||
"""
|
"""
|
||||||
Bedrock doesn't support tool calling without `tools=` param specified.
|
Bedrock doesn't support tool calling without `tools=` param specified.
|
||||||
|
@ -420,11 +426,11 @@ class AmazonConverseConfig(BaseConfig):
|
||||||
AmazonConverseConfig.__annotations__.keys()
|
AmazonConverseConfig.__annotations__.keys()
|
||||||
) + ["top_k"]
|
) + ["top_k"]
|
||||||
supported_tool_call_params = ["tools", "tool_choice"]
|
supported_tool_call_params = ["tools", "tool_choice"]
|
||||||
supported_guardrail_params = ["guardrailConfig"]
|
supported_config_params = list(self.get_config_blocks().keys())
|
||||||
total_supported_params = (
|
total_supported_params = (
|
||||||
supported_converse_params
|
supported_converse_params
|
||||||
+ supported_tool_call_params
|
+ supported_tool_call_params
|
||||||
+ supported_guardrail_params
|
+ supported_config_params
|
||||||
)
|
)
|
||||||
inference_params.pop("json_mode", None) # used for handling json_schema
|
inference_params.pop("json_mode", None) # used for handling json_schema
|
||||||
|
|
||||||
|
@ -463,12 +469,11 @@ class AmazonConverseConfig(BaseConfig):
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Guardrail Config
|
# Handle all config blocks
|
||||||
guardrail_config: Optional[GuardrailConfigBlock] = None
|
for config_name, config_class in self.get_config_blocks().items():
|
||||||
request_guardrails_config = inference_params.pop("guardrailConfig", None)
|
config_value = inference_params.pop(config_name, None)
|
||||||
if request_guardrails_config is not None:
|
if config_value is not None:
|
||||||
guardrail_config = GuardrailConfigBlock(**request_guardrails_config)
|
data[config_name] = config_class(**config_value) # type: ignore
|
||||||
data["guardrailConfig"] = guardrail_config
|
|
||||||
|
|
||||||
# Tool Config
|
# Tool Config
|
||||||
if bedrock_tool_config is not None:
|
if bedrock_tool_config is not None:
|
||||||
|
|
|
@ -239,6 +239,7 @@ def cost_per_token(
|
||||||
Raises:
|
Raises:
|
||||||
Exception if model requires >128k pricing, but model cost not mapped
|
Exception if model requires >128k pricing, but model cost not mapped
|
||||||
"""
|
"""
|
||||||
|
|
||||||
## GET MODEL INFO
|
## GET MODEL INFO
|
||||||
model_info = litellm.get_model_info(
|
model_info = litellm.get_model_info(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import json
|
import json
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -14,15 +14,11 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||||
VertexAIError,
|
VertexAIError,
|
||||||
VertexLLM,
|
VertexLLM,
|
||||||
)
|
)
|
||||||
from litellm.types.llms.vertex_ai import (
|
from litellm.types.utils import EmbeddingResponse
|
||||||
Instance,
|
|
||||||
InstanceImage,
|
from .transformation import VertexAIMultimodalEmbeddingConfig
|
||||||
InstanceVideo,
|
|
||||||
MultimodalPredictions,
|
vertex_multimodal_embedding_handler = VertexAIMultimodalEmbeddingConfig()
|
||||||
VertexMultimodalEmbeddingRequest,
|
|
||||||
)
|
|
||||||
from litellm.types.utils import Embedding, EmbeddingResponse
|
|
||||||
from litellm.utils import is_base64_encoded
|
|
||||||
|
|
||||||
|
|
||||||
class VertexMultimodalEmbedding(VertexLLM):
|
class VertexMultimodalEmbedding(VertexLLM):
|
||||||
|
@ -41,9 +37,11 @@ class VertexMultimodalEmbedding(VertexLLM):
|
||||||
model_response: EmbeddingResponse,
|
model_response: EmbeddingResponse,
|
||||||
custom_llm_provider: Literal["gemini", "vertex_ai"],
|
custom_llm_provider: Literal["gemini", "vertex_ai"],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
logging_obj: LiteLLMLoggingObj,
|
logging_obj: LiteLLMLoggingObj,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
|
headers: dict = {},
|
||||||
encoding=None,
|
encoding=None,
|
||||||
vertex_project=None,
|
vertex_project=None,
|
||||||
vertex_location=None,
|
vertex_location=None,
|
||||||
|
@ -86,31 +84,18 @@ class VertexMultimodalEmbedding(VertexLLM):
|
||||||
else:
|
else:
|
||||||
sync_handler = client # type: ignore
|
sync_handler = client # type: ignore
|
||||||
|
|
||||||
optional_params = optional_params or {}
|
request_data = vertex_multimodal_embedding_handler.transform_embedding_request(
|
||||||
|
model, input, optional_params, headers
|
||||||
request_data = VertexMultimodalEmbeddingRequest()
|
|
||||||
|
|
||||||
if "instances" in optional_params:
|
|
||||||
request_data["instances"] = optional_params["instances"]
|
|
||||||
elif isinstance(input, list):
|
|
||||||
vertex_instances: List[Instance] = self.process_openai_embedding_input(
|
|
||||||
_input=input
|
|
||||||
)
|
)
|
||||||
request_data["instances"] = vertex_instances
|
|
||||||
|
|
||||||
else:
|
headers = vertex_multimodal_embedding_handler.validate_environment(
|
||||||
# construct instances
|
headers=headers,
|
||||||
vertex_request_instance = Instance(**optional_params)
|
model=model,
|
||||||
|
messages=[],
|
||||||
if isinstance(input, str):
|
optional_params=optional_params,
|
||||||
vertex_request_instance = self._process_input_element(input)
|
api_key=auth_header,
|
||||||
|
api_base=api_base,
|
||||||
request_data["instances"] = [vertex_request_instance]
|
)
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json; charset=utf-8",
|
|
||||||
"Authorization": f"Bearer {auth_header}",
|
|
||||||
}
|
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
|
@ -132,6 +117,10 @@ class VertexMultimodalEmbedding(VertexLLM):
|
||||||
headers=headers,
|
headers=headers,
|
||||||
client=client,
|
client=client,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = sync_handler.post(
|
response = sync_handler.post(
|
||||||
|
@ -140,34 +129,30 @@ class VertexMultimodalEmbedding(VertexLLM):
|
||||||
data=json.dumps(request_data),
|
data=json.dumps(request_data),
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
return vertex_multimodal_embedding_handler.transform_embedding_response(
|
||||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
|
||||||
|
|
||||||
_json_response = response.json()
|
|
||||||
if "predictions" not in _json_response:
|
|
||||||
raise litellm.InternalServerError(
|
|
||||||
message=f"embedding response does not contain 'predictions', got {_json_response}",
|
|
||||||
llm_provider="vertex_ai",
|
|
||||||
model=model,
|
model=model,
|
||||||
|
raw_response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
|
request_data=request_data,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
_predictions = _json_response["predictions"]
|
|
||||||
vertex_predictions = MultimodalPredictions(predictions=_predictions)
|
|
||||||
model_response.data = self.transform_embedding_response_to_openai(
|
|
||||||
predictions=vertex_predictions
|
|
||||||
)
|
|
||||||
model_response.model = model
|
|
||||||
|
|
||||||
return model_response
|
|
||||||
|
|
||||||
async def async_multimodal_embedding(
|
async def async_multimodal_embedding(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
api_base: str,
|
api_base: str,
|
||||||
data: VertexMultimodalEmbeddingRequest,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
data: dict,
|
||||||
model_response: litellm.EmbeddingResponse,
|
model_response: litellm.EmbeddingResponse,
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
headers={},
|
headers={},
|
||||||
client: Optional[AsyncHTTPHandler] = None,
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
) -> litellm.EmbeddingResponse:
|
) -> litellm.EmbeddingResponse:
|
||||||
if client is None:
|
if client is None:
|
||||||
_params = {}
|
_params = {}
|
||||||
|
@ -191,112 +176,13 @@ class VertexMultimodalEmbedding(VertexLLM):
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException:
|
||||||
raise VertexAIError(status_code=408, message="Timeout error occurred.")
|
raise VertexAIError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
_json_response = response.json()
|
return vertex_multimodal_embedding_handler.transform_embedding_response(
|
||||||
if "predictions" not in _json_response:
|
|
||||||
raise litellm.InternalServerError(
|
|
||||||
message=f"embedding response does not contain 'predictions', got {_json_response}",
|
|
||||||
llm_provider="vertex_ai",
|
|
||||||
model=model,
|
model=model,
|
||||||
|
raw_response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
|
request_data=data,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
_predictions = _json_response["predictions"]
|
|
||||||
|
|
||||||
vertex_predictions = MultimodalPredictions(predictions=_predictions)
|
|
||||||
model_response.data = self.transform_embedding_response_to_openai(
|
|
||||||
predictions=vertex_predictions
|
|
||||||
)
|
|
||||||
model_response.model = model
|
|
||||||
|
|
||||||
return model_response
|
|
||||||
|
|
||||||
def _process_input_element(self, input_element: str) -> Instance:
|
|
||||||
"""
|
|
||||||
Process the input element for multimodal embedding requests. checks if the if the input is gcs uri, base64 encoded image or plain text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_element (str): The input element to process.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: A dictionary representing the processed input element.
|
|
||||||
"""
|
|
||||||
if len(input_element) == 0:
|
|
||||||
return Instance(text=input_element)
|
|
||||||
elif "gs://" in input_element:
|
|
||||||
if "mp4" in input_element:
|
|
||||||
return Instance(video=InstanceVideo(gcsUri=input_element))
|
|
||||||
else:
|
|
||||||
return Instance(image=InstanceImage(gcsUri=input_element))
|
|
||||||
elif is_base64_encoded(s=input_element):
|
|
||||||
return Instance(
|
|
||||||
image=InstanceImage(
|
|
||||||
bytesBase64Encoded=(
|
|
||||||
input_element.split(",")[1]
|
|
||||||
if "," in input_element
|
|
||||||
else input_element
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return Instance(text=input_element)
|
|
||||||
|
|
||||||
def process_openai_embedding_input(
|
|
||||||
self, _input: Union[list, str]
|
|
||||||
) -> List[Instance]:
|
|
||||||
"""
|
|
||||||
Process the input for multimodal embedding requests.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
_input (Union[list, str]): The input data to process.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Instance]: A list of processed VertexAI Instance objects.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_input_list = None
|
|
||||||
if not isinstance(_input, list):
|
|
||||||
_input_list = [_input]
|
|
||||||
else:
|
|
||||||
_input_list = _input
|
|
||||||
|
|
||||||
processed_instances = []
|
|
||||||
for element in _input_list:
|
|
||||||
if isinstance(element, str):
|
|
||||||
instance = Instance(**self._process_input_element(element))
|
|
||||||
elif isinstance(element, dict):
|
|
||||||
instance = Instance(**element)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported input type: {type(element)}")
|
|
||||||
processed_instances.append(instance)
|
|
||||||
|
|
||||||
return processed_instances
|
|
||||||
|
|
||||||
def transform_embedding_response_to_openai(
|
|
||||||
self, predictions: MultimodalPredictions
|
|
||||||
) -> List[Embedding]:
|
|
||||||
|
|
||||||
openai_embeddings: List[Embedding] = []
|
|
||||||
if "predictions" in predictions:
|
|
||||||
for idx, _prediction in enumerate(predictions["predictions"]):
|
|
||||||
if _prediction:
|
|
||||||
if "textEmbedding" in _prediction:
|
|
||||||
openai_embedding_object = Embedding(
|
|
||||||
embedding=_prediction["textEmbedding"],
|
|
||||||
index=idx,
|
|
||||||
object="embedding",
|
|
||||||
)
|
|
||||||
openai_embeddings.append(openai_embedding_object)
|
|
||||||
elif "imageEmbedding" in _prediction:
|
|
||||||
openai_embedding_object = Embedding(
|
|
||||||
embedding=_prediction["imageEmbedding"],
|
|
||||||
index=idx,
|
|
||||||
object="embedding",
|
|
||||||
)
|
|
||||||
openai_embeddings.append(openai_embedding_object)
|
|
||||||
elif "videoEmbeddings" in _prediction:
|
|
||||||
for video_embedding in _prediction["videoEmbeddings"]:
|
|
||||||
openai_embedding_object = Embedding(
|
|
||||||
embedding=video_embedding["embedding"],
|
|
||||||
index=idx,
|
|
||||||
object="embedding",
|
|
||||||
)
|
|
||||||
openai_embeddings.append(openai_embedding_object)
|
|
||||||
return openai_embeddings
|
|
||||||
|
|
297
litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
Normal file
297
litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
Normal file
|
@ -0,0 +1,297 @@
|
||||||
|
from typing import List, Optional, Union, cast
|
||||||
|
|
||||||
|
from httpx import Headers, Response
|
||||||
|
|
||||||
|
from litellm.exceptions import InternalServerError
|
||||||
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
|
from litellm.llms.base_llm.embedding.transformation import LiteLLMLoggingObj
|
||||||
|
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||||
|
from litellm.types.llms.vertex_ai import (
|
||||||
|
Instance,
|
||||||
|
InstanceImage,
|
||||||
|
InstanceVideo,
|
||||||
|
MultimodalPredictions,
|
||||||
|
VertexMultimodalEmbeddingRequest,
|
||||||
|
)
|
||||||
|
from litellm.types.utils import (
|
||||||
|
Embedding,
|
||||||
|
EmbeddingResponse,
|
||||||
|
PromptTokensDetailsWrapper,
|
||||||
|
Usage,
|
||||||
|
)
|
||||||
|
from litellm.utils import _count_characters, is_base64_encoded
|
||||||
|
|
||||||
|
from ...base_llm.embedding.transformation import BaseEmbeddingConfig
|
||||||
|
from ..common_utils import VertexAIError
|
||||||
|
|
||||||
|
|
||||||
|
class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
|
||||||
|
def get_supported_openai_params(self, model: str) -> list:
|
||||||
|
return ["dimensions"]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "dimensions":
|
||||||
|
optional_params["outputDimensionality"] = value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
def validate_environment(
|
||||||
|
self,
|
||||||
|
headers: dict,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
default_headers = {
|
||||||
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
}
|
||||||
|
headers.update(default_headers)
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def _process_input_element(self, input_element: str) -> Instance:
|
||||||
|
"""
|
||||||
|
Process the input element for multimodal embedding requests. checks if the if the input is gcs uri, base64 encoded image or plain text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_element (str): The input element to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: A dictionary representing the processed input element.
|
||||||
|
"""
|
||||||
|
if len(input_element) == 0:
|
||||||
|
return Instance(text=input_element)
|
||||||
|
elif "gs://" in input_element:
|
||||||
|
if "mp4" in input_element:
|
||||||
|
return Instance(video=InstanceVideo(gcsUri=input_element))
|
||||||
|
else:
|
||||||
|
return Instance(image=InstanceImage(gcsUri=input_element))
|
||||||
|
elif is_base64_encoded(s=input_element):
|
||||||
|
return Instance(
|
||||||
|
image=InstanceImage(
|
||||||
|
bytesBase64Encoded=(
|
||||||
|
input_element.split(",")[1]
|
||||||
|
if "," in input_element
|
||||||
|
else input_element
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return Instance(text=input_element)
|
||||||
|
|
||||||
|
def process_openai_embedding_input(
|
||||||
|
self, _input: Union[list, str]
|
||||||
|
) -> List[Instance]:
|
||||||
|
"""
|
||||||
|
Process the input for multimodal embedding requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_input (Union[list, str]): The input data to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[Instance, List[Instance]]: Either a single Instance or list of Instance objects.
|
||||||
|
"""
|
||||||
|
_input_list = [_input] if not isinstance(_input, list) else _input
|
||||||
|
processed_instances = []
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < len(_input_list):
|
||||||
|
current = _input_list[i]
|
||||||
|
|
||||||
|
# Look ahead for potential media elements
|
||||||
|
next_elem = _input_list[i + 1] if i + 1 < len(_input_list) else None
|
||||||
|
|
||||||
|
# If current is a text and next is a GCS URI, or current is a GCS URI
|
||||||
|
if isinstance(current, str):
|
||||||
|
instance_args: Instance = {}
|
||||||
|
|
||||||
|
# Process current element
|
||||||
|
if "gs://" not in current:
|
||||||
|
instance_args["text"] = current
|
||||||
|
elif "mp4" in current:
|
||||||
|
instance_args["video"] = InstanceVideo(gcsUri=current)
|
||||||
|
else:
|
||||||
|
instance_args["image"] = InstanceImage(gcsUri=current)
|
||||||
|
|
||||||
|
# Check next element if it's a GCS URI
|
||||||
|
if next_elem and isinstance(next_elem, str) and "gs://" in next_elem:
|
||||||
|
if "mp4" in next_elem:
|
||||||
|
instance_args["video"] = InstanceVideo(gcsUri=next_elem)
|
||||||
|
else:
|
||||||
|
instance_args["image"] = InstanceImage(gcsUri=next_elem)
|
||||||
|
i += 2 # Skip next element since we processed it
|
||||||
|
else:
|
||||||
|
i += 1 # Move to next element
|
||||||
|
|
||||||
|
processed_instances.append(instance_args)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle dict or other types
|
||||||
|
if isinstance(current, dict):
|
||||||
|
instance = Instance(**current)
|
||||||
|
processed_instances.append(instance)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported input type: {type(current)}")
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return processed_instances
|
||||||
|
|
||||||
|
def transform_embedding_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
input: AllEmbeddingInputValues,
|
||||||
|
optional_params: dict,
|
||||||
|
headers: dict,
|
||||||
|
) -> dict:
|
||||||
|
optional_params = optional_params or {}
|
||||||
|
|
||||||
|
request_data = VertexMultimodalEmbeddingRequest(instances=[])
|
||||||
|
|
||||||
|
if "instances" in optional_params:
|
||||||
|
request_data["instances"] = optional_params["instances"]
|
||||||
|
elif isinstance(input, list):
|
||||||
|
vertex_instances: List[Instance] = self.process_openai_embedding_input(
|
||||||
|
_input=input
|
||||||
|
)
|
||||||
|
request_data["instances"] = vertex_instances
|
||||||
|
|
||||||
|
else:
|
||||||
|
# construct instances
|
||||||
|
vertex_request_instance = Instance(**optional_params)
|
||||||
|
|
||||||
|
if isinstance(input, str):
|
||||||
|
vertex_request_instance = self._process_input_element(input)
|
||||||
|
|
||||||
|
request_data["instances"] = [vertex_request_instance]
|
||||||
|
|
||||||
|
return cast(dict, request_data)
|
||||||
|
|
||||||
|
def transform_embedding_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
raw_response: Response,
|
||||||
|
model_response: EmbeddingResponse,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
|
api_key: Optional[str],
|
||||||
|
request_data: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
if raw_response.status_code != 200:
|
||||||
|
raise Exception(f"Error: {raw_response.status_code} {raw_response.text}")
|
||||||
|
|
||||||
|
_json_response = raw_response.json()
|
||||||
|
if "predictions" not in _json_response:
|
||||||
|
raise InternalServerError(
|
||||||
|
message=f"embedding response does not contain 'predictions', got {_json_response}",
|
||||||
|
llm_provider="vertex_ai",
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
_predictions = _json_response["predictions"]
|
||||||
|
vertex_predictions = MultimodalPredictions(predictions=_predictions)
|
||||||
|
model_response.data = self.transform_embedding_response_to_openai(
|
||||||
|
predictions=vertex_predictions
|
||||||
|
)
|
||||||
|
model_response.model = model
|
||||||
|
|
||||||
|
model_response.usage = self.calculate_usage(
|
||||||
|
request_data=cast(VertexMultimodalEmbeddingRequest, request_data),
|
||||||
|
vertex_predictions=vertex_predictions,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def calculate_usage(
|
||||||
|
self,
|
||||||
|
request_data: VertexMultimodalEmbeddingRequest,
|
||||||
|
vertex_predictions: MultimodalPredictions,
|
||||||
|
) -> Usage:
|
||||||
|
## Calculate text embeddings usage
|
||||||
|
prompt: Optional[str] = None
|
||||||
|
character_count: Optional[int] = None
|
||||||
|
|
||||||
|
for instance in request_data["instances"]:
|
||||||
|
text = instance.get("text")
|
||||||
|
if text:
|
||||||
|
if prompt is None:
|
||||||
|
prompt = text
|
||||||
|
else:
|
||||||
|
prompt += text
|
||||||
|
|
||||||
|
if prompt is not None:
|
||||||
|
character_count = _count_characters(prompt)
|
||||||
|
|
||||||
|
## Calculate image embeddings usage
|
||||||
|
image_count = 0
|
||||||
|
for instance in request_data["instances"]:
|
||||||
|
if instance.get("image"):
|
||||||
|
image_count += 1
|
||||||
|
|
||||||
|
## Calculate video embeddings usage
|
||||||
|
video_length_seconds = 0
|
||||||
|
for prediction in vertex_predictions["predictions"]:
|
||||||
|
video_embeddings = prediction.get("videoEmbeddings")
|
||||||
|
if video_embeddings:
|
||||||
|
for embedding in video_embeddings:
|
||||||
|
duration = embedding["endOffsetSec"] - embedding["startOffsetSec"]
|
||||||
|
video_length_seconds += duration
|
||||||
|
|
||||||
|
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||||
|
character_count=character_count,
|
||||||
|
image_count=image_count,
|
||||||
|
video_length_seconds=video_length_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Usage(
|
||||||
|
prompt_tokens=0,
|
||||||
|
completion_tokens=0,
|
||||||
|
total_tokens=0,
|
||||||
|
prompt_tokens_details=prompt_tokens_details,
|
||||||
|
)
|
||||||
|
|
||||||
|
def transform_embedding_response_to_openai(
|
||||||
|
self, predictions: MultimodalPredictions
|
||||||
|
) -> List[Embedding]:
|
||||||
|
|
||||||
|
openai_embeddings: List[Embedding] = []
|
||||||
|
if "predictions" in predictions:
|
||||||
|
for idx, _prediction in enumerate(predictions["predictions"]):
|
||||||
|
if _prediction:
|
||||||
|
if "textEmbedding" in _prediction:
|
||||||
|
openai_embedding_object = Embedding(
|
||||||
|
embedding=_prediction["textEmbedding"],
|
||||||
|
index=idx,
|
||||||
|
object="embedding",
|
||||||
|
)
|
||||||
|
openai_embeddings.append(openai_embedding_object)
|
||||||
|
elif "imageEmbedding" in _prediction:
|
||||||
|
openai_embedding_object = Embedding(
|
||||||
|
embedding=_prediction["imageEmbedding"],
|
||||||
|
index=idx,
|
||||||
|
object="embedding",
|
||||||
|
)
|
||||||
|
openai_embeddings.append(openai_embedding_object)
|
||||||
|
elif "videoEmbeddings" in _prediction:
|
||||||
|
for video_embedding in _prediction["videoEmbeddings"]:
|
||||||
|
openai_embedding_object = Embedding(
|
||||||
|
embedding=video_embedding["embedding"],
|
||||||
|
index=idx,
|
||||||
|
object="embedding",
|
||||||
|
)
|
||||||
|
openai_embeddings.append(openai_embedding_object)
|
||||||
|
return openai_embeddings
|
||||||
|
|
||||||
|
def get_error_class(
|
||||||
|
self, error_message: str, status_code: int, headers: Union[dict, Headers]
|
||||||
|
) -> BaseLLMException:
|
||||||
|
return VertexAIError(
|
||||||
|
status_code=status_code, message=error_message, headers=headers
|
||||||
|
)
|
|
@ -3723,6 +3723,7 @@ def embedding( # noqa: PLR0915
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params_dict,
|
||||||
model_response=EmbeddingResponse(),
|
model_response=EmbeddingResponse(),
|
||||||
vertex_project=vertex_ai_project,
|
vertex_project=vertex_ai_project,
|
||||||
vertex_location=vertex_ai_location,
|
vertex_location=vertex_ai_location,
|
||||||
|
|
|
@ -5409,6 +5409,23 @@
|
||||||
"supported_modalities": ["text", "image", "video"],
|
"supported_modalities": ["text", "image", "video"],
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
|
||||||
},
|
},
|
||||||
|
"multimodalembedding@001": {
|
||||||
|
"max_tokens": 2048,
|
||||||
|
"max_input_tokens": 2048,
|
||||||
|
"output_vector_size": 768,
|
||||||
|
"input_cost_per_character": 0.0000002,
|
||||||
|
"input_cost_per_image": 0.0001,
|
||||||
|
"input_cost_per_video_per_second": 0.0005,
|
||||||
|
"input_cost_per_video_per_second_above_8s_interval": 0.0010,
|
||||||
|
"input_cost_per_video_per_second_above_15s_interval": 0.0020,
|
||||||
|
"input_cost_per_token": 0.0000008,
|
||||||
|
"output_cost_per_token": 0,
|
||||||
|
"litellm_provider": "vertex_ai-embedding-models",
|
||||||
|
"mode": "embedding",
|
||||||
|
"supported_endpoints": ["/v1/embeddings"],
|
||||||
|
"supported_modalities": ["text", "image", "video"],
|
||||||
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
|
||||||
|
},
|
||||||
"text-embedding-large-exp-03-07": {
|
"text-embedding-large-exp-03-07": {
|
||||||
"max_tokens": 8192,
|
"max_tokens": 8192,
|
||||||
"max_input_tokens": 8192,
|
"max_input_tokens": 8192,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, cast, get_type_hints
|
from typing import Any, Dict, Union, cast, get_type_hints
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
|
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
|
||||||
|
@ -78,16 +78,22 @@ class ResponsesAPIRequestUtils:
|
||||||
|
|
||||||
class ResponseAPILoggingUtils:
|
class ResponseAPILoggingUtils:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_response_api_usage(usage: dict) -> bool:
|
def _is_response_api_usage(usage: Union[dict, ResponseAPIUsage]) -> bool:
|
||||||
"""returns True if usage is from OpenAI Response API"""
|
"""returns True if usage is from OpenAI Response API"""
|
||||||
|
if isinstance(usage, ResponseAPIUsage):
|
||||||
|
return True
|
||||||
if "input_tokens" in usage and "output_tokens" in usage:
|
if "input_tokens" in usage and "output_tokens" in usage:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _transform_response_api_usage_to_chat_usage(usage: dict) -> Usage:
|
def _transform_response_api_usage_to_chat_usage(
|
||||||
|
usage: Union[dict, ResponseAPIUsage]
|
||||||
|
) -> Usage:
|
||||||
"""Tranforms the ResponseAPIUsage object to a Usage object"""
|
"""Tranforms the ResponseAPIUsage object to a Usage object"""
|
||||||
response_api_usage: ResponseAPIUsage = ResponseAPIUsage(**usage)
|
response_api_usage: ResponseAPIUsage = (
|
||||||
|
ResponseAPIUsage(**usage) if isinstance(usage, dict) else usage
|
||||||
|
)
|
||||||
prompt_tokens: int = response_api_usage.input_tokens or 0
|
prompt_tokens: int = response_api_usage.input_tokens or 0
|
||||||
completion_tokens: int = response_api_usage.output_tokens or 0
|
completion_tokens: int = response_api_usage.output_tokens or 0
|
||||||
return Usage(
|
return Usage(
|
||||||
|
|
|
@ -191,6 +191,10 @@ class ContentBlockDeltaEvent(TypedDict, total=False):
|
||||||
reasoningContent: BedrockConverseReasoningContentBlockDelta
|
reasoningContent: BedrockConverseReasoningContentBlockDelta
|
||||||
|
|
||||||
|
|
||||||
|
class PerformanceConfigBlock(TypedDict):
|
||||||
|
latency: Literal["optimized", "throughput"]
|
||||||
|
|
||||||
|
|
||||||
class CommonRequestObject(
|
class CommonRequestObject(
|
||||||
TypedDict, total=False
|
TypedDict, total=False
|
||||||
): # common request object across sync + async flows
|
): # common request object across sync + async flows
|
||||||
|
@ -200,6 +204,7 @@ class CommonRequestObject(
|
||||||
system: List[SystemContentBlock]
|
system: List[SystemContentBlock]
|
||||||
toolConfig: ToolConfigBlock
|
toolConfig: ToolConfigBlock
|
||||||
guardrailConfig: Optional[GuardrailConfigBlock]
|
guardrailConfig: Optional[GuardrailConfigBlock]
|
||||||
|
performanceConfig: Optional[PerformanceConfigBlock]
|
||||||
|
|
||||||
|
|
||||||
class RequestObject(CommonRequestObject, total=False):
|
class RequestObject(CommonRequestObject, total=False):
|
||||||
|
@ -429,7 +434,9 @@ class AmazonNovaCanvasColorGuidedGenerationParams(TypedDict, total=False):
|
||||||
negativeText: str
|
negativeText: str
|
||||||
|
|
||||||
|
|
||||||
class AmazonNovaCanvasColorGuidedRequest(AmazonNovaCanvasRequestBase, TypedDict, total=False):
|
class AmazonNovaCanvasColorGuidedRequest(
|
||||||
|
AmazonNovaCanvasRequestBase, TypedDict, total=False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Request for Amazon Nova Canvas Color Guided Generation API
|
Request for Amazon Nova Canvas Color Guided Generation API
|
||||||
|
|
||||||
|
|
|
@ -369,9 +369,15 @@ class ResponseTuningJob(TypedDict):
|
||||||
updateTime: Optional[str]
|
updateTime: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class VideoSegmentConfig(TypedDict, total=False):
|
||||||
|
startOffsetSec: int
|
||||||
|
endOffsetSec: int
|
||||||
|
intervalSec: int
|
||||||
|
|
||||||
|
|
||||||
class InstanceVideo(TypedDict, total=False):
|
class InstanceVideo(TypedDict, total=False):
|
||||||
gcsUri: str
|
gcsUri: str
|
||||||
videoSegmentConfig: Tuple[float, float, float]
|
videoSegmentConfig: VideoSegmentConfig
|
||||||
|
|
||||||
|
|
||||||
class InstanceImage(TypedDict, total=False):
|
class InstanceImage(TypedDict, total=False):
|
||||||
|
@ -386,7 +392,7 @@ class Instance(TypedDict, total=False):
|
||||||
video: InstanceVideo
|
video: InstanceVideo
|
||||||
|
|
||||||
|
|
||||||
class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
|
class VertexMultimodalEmbeddingRequest(TypedDict):
|
||||||
instances: List[Instance]
|
instances: List[Instance]
|
||||||
|
|
||||||
|
|
||||||
|
@ -402,7 +408,7 @@ class MultimodalPrediction(TypedDict, total=False):
|
||||||
videoEmbeddings: List[VideoEmbedding]
|
videoEmbeddings: List[VideoEmbedding]
|
||||||
|
|
||||||
|
|
||||||
class MultimodalPredictions(TypedDict, total=False):
|
class MultimodalPredictions(TypedDict):
|
||||||
predictions: List[MultimodalPrediction]
|
predictions: List[MultimodalPrediction]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -779,6 +779,24 @@ class PromptTokensDetailsWrapper(
|
||||||
image_tokens: Optional[int] = None
|
image_tokens: Optional[int] = None
|
||||||
"""Image tokens sent to the model."""
|
"""Image tokens sent to the model."""
|
||||||
|
|
||||||
|
character_count: Optional[int] = None
|
||||||
|
"""Character count sent to the model. Used for Vertex AI multimodal embeddings."""
|
||||||
|
|
||||||
|
image_count: Optional[int] = None
|
||||||
|
"""Number of images sent to the model. Used for Vertex AI multimodal embeddings."""
|
||||||
|
|
||||||
|
video_length_seconds: Optional[float] = None
|
||||||
|
"""Length of videos sent to the model. Used for Vertex AI multimodal embeddings."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
if self.character_count is None:
|
||||||
|
del self.character_count
|
||||||
|
if self.image_count is None:
|
||||||
|
del self.image_count
|
||||||
|
if self.video_length_seconds is None:
|
||||||
|
del self.video_length_seconds
|
||||||
|
|
||||||
|
|
||||||
class Usage(CompletionUsage):
|
class Usage(CompletionUsage):
|
||||||
_cache_creation_input_tokens: int = PrivateAttr(
|
_cache_creation_input_tokens: int = PrivateAttr(
|
||||||
|
|
|
@ -5409,6 +5409,23 @@
|
||||||
"supported_modalities": ["text", "image", "video"],
|
"supported_modalities": ["text", "image", "video"],
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
|
||||||
},
|
},
|
||||||
|
"multimodalembedding@001": {
|
||||||
|
"max_tokens": 2048,
|
||||||
|
"max_input_tokens": 2048,
|
||||||
|
"output_vector_size": 768,
|
||||||
|
"input_cost_per_character": 0.0000002,
|
||||||
|
"input_cost_per_image": 0.0001,
|
||||||
|
"input_cost_per_video_per_second": 0.0005,
|
||||||
|
"input_cost_per_video_per_second_above_8s_interval": 0.0010,
|
||||||
|
"input_cost_per_video_per_second_above_15s_interval": 0.0020,
|
||||||
|
"input_cost_per_token": 0.0000008,
|
||||||
|
"output_cost_per_token": 0,
|
||||||
|
"litellm_provider": "vertex_ai-embedding-models",
|
||||||
|
"mode": "embedding",
|
||||||
|
"supported_endpoints": ["/v1/embeddings"],
|
||||||
|
"supported_modalities": ["text", "image", "video"],
|
||||||
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models"
|
||||||
|
},
|
||||||
"text-embedding-large-exp-03-07": {
|
"text-embedding-large-exp-03-07": {
|
||||||
"max_tokens": 8192,
|
"max_tokens": 8192,
|
||||||
"max_input_tokens": 8192,
|
"max_input_tokens": 8192,
|
||||||
|
|
|
@ -1,43 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
sys.path.insert(
|
|
||||||
0, os.path.abspath("../../../../..")
|
|
||||||
) # Adds the parent directory to the system path
|
|
||||||
|
|
||||||
from litellm.llms.vertex_ai.multimodal_embeddings.embedding_handler import (
|
|
||||||
VertexMultimodalEmbedding,
|
|
||||||
)
|
|
||||||
from litellm.types.llms.vertex_ai import Instance, InstanceImage
|
|
||||||
|
|
||||||
|
|
||||||
class TestVertexMultimodalEmbedding:
|
|
||||||
def setup_method(self):
|
|
||||||
self.embedding_handler = VertexMultimodalEmbedding()
|
|
||||||
|
|
||||||
def test_process_openai_embedding_input(self):
|
|
||||||
input_data = [
|
|
||||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=",
|
|
||||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=",
|
|
||||||
]
|
|
||||||
expected_output = [
|
|
||||||
Instance(
|
|
||||||
image=InstanceImage(bytesBase64Encoded=input_data[0].split(",")[1])
|
|
||||||
),
|
|
||||||
Instance(
|
|
||||||
image=InstanceImage(bytesBase64Encoded=input_data[1].split(",")[1])
|
|
||||||
),
|
|
||||||
]
|
|
||||||
assert (
|
|
||||||
self.embedding_handler._process_input_element(input_data[0])
|
|
||||||
== expected_output[0]
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
self.embedding_handler._process_input_element(input_data[1])
|
|
||||||
== expected_output[1]
|
|
||||||
)
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm.llms.vertex_ai.multimodal_embeddings.transformation import (
|
||||||
|
VertexAIMultimodalEmbeddingConfig,
|
||||||
|
)
|
||||||
|
from litellm.types.llms.vertex_ai import Instance, InstanceImage, InstanceVideo
|
||||||
|
|
||||||
|
|
||||||
|
class TestVertexMultimodalEmbedding:
|
||||||
|
def setup_method(self):
|
||||||
|
self.config = VertexAIMultimodalEmbeddingConfig()
|
||||||
|
|
||||||
|
def test_process_openai_embedding_input(self):
|
||||||
|
input_data = [
|
||||||
|
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=",
|
||||||
|
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=",
|
||||||
|
]
|
||||||
|
expected_output = [
|
||||||
|
Instance(
|
||||||
|
image=InstanceImage(bytesBase64Encoded=input_data[0].split(",")[1])
|
||||||
|
),
|
||||||
|
Instance(
|
||||||
|
image=InstanceImage(bytesBase64Encoded=input_data[1].split(",")[1])
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert self.config._process_input_element(input_data[0]) == expected_output[0]
|
||||||
|
assert self.config._process_input_element(input_data[1]) == expected_output[1]
|
||||||
|
|
||||||
|
def test_process_str_and_video_input(self):
|
||||||
|
input_data = ["hi", "gs://my-bucket/embeddings/supermarket-video.mp4"]
|
||||||
|
expected_output = [
|
||||||
|
Instance(
|
||||||
|
text="hi",
|
||||||
|
video=InstanceVideo(
|
||||||
|
gcsUri="gs://my-bucket/embeddings/supermarket-video.mp4"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert self.config.process_openai_embedding_input(input_data) == expected_output
|
||||||
|
|
||||||
|
def test_process_list_of_str_and_str_input(self):
|
||||||
|
input_data = ["hi", "hello"]
|
||||||
|
expected_output = [
|
||||||
|
Instance(text="hi"),
|
||||||
|
Instance(text="hello"),
|
||||||
|
]
|
||||||
|
assert self.config.process_openai_embedding_input(input_data) == expected_output
|
||||||
|
|
||||||
|
def test_process_list_of_str_and_video_input(self):
|
||||||
|
input_data = [
|
||||||
|
"hi",
|
||||||
|
"hello",
|
||||||
|
"gs://my-bucket/embeddings/supermarket-video.mp4",
|
||||||
|
"hey",
|
||||||
|
]
|
||||||
|
expected_output = [
|
||||||
|
Instance(text="hi"),
|
||||||
|
Instance(
|
||||||
|
text="hello",
|
||||||
|
video=InstanceVideo(
|
||||||
|
gcsUri="gs://my-bucket/embeddings/supermarket-video.mp4"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Instance(text="hey"),
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
self.config.process_openai_embedding_input(input_data) == expected_output
|
||||||
|
), f"Expected {expected_output}, but got {self.config.process_openai_embedding_input(input_data)}"
|
|
@ -239,3 +239,23 @@ async def test_url_with_format_param_openai(model, sync_mode):
|
||||||
json_str = json.dumps(mock_client.call_args.kwargs)
|
json_str = json.dumps(mock_client.call_args.kwargs)
|
||||||
|
|
||||||
assert "format" not in json_str
|
assert "format" not in json_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_bedrock_latency_optimized_inference():
|
||||||
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
|
|
||||||
|
client = HTTPHandler()
|
||||||
|
with patch.object(client, "post") as mock_post:
|
||||||
|
try:
|
||||||
|
response = litellm.completion(
|
||||||
|
model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||||
|
performanceConfig={"latency": "optimized"},
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||||
|
assert json_data["performanceConfig"]["latency"] == "optimized"
|
||||||
|
|
|
@ -416,6 +416,7 @@ def validate_web_search_annotations(annotations: ChatCompletionAnnotation):
|
||||||
validate_response_url_citation(url_citation)
|
validate_response_url_citation(url_citation)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.flaky(reruns=3)
|
||||||
def test_openai_web_search():
|
def test_openai_web_search():
|
||||||
"""Makes a simple web search request and validates the response contains web search annotations and all expected fields are present"""
|
"""Makes a simple web search request and validates the response contains web search annotations and all expected fields are present"""
|
||||||
litellm._turn_on_debug()
|
litellm._turn_on_debug()
|
||||||
|
|
|
@ -2197,6 +2197,55 @@ def test_vertexai_embedding_embedding_latest():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_vertexai_multimodalembedding_embedding_latest():
|
||||||
|
try:
|
||||||
|
import requests, base64
|
||||||
|
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
litellm._turn_on_debug()
|
||||||
|
|
||||||
|
response = embedding(
|
||||||
|
model="vertex_ai/multimodalembedding@001",
|
||||||
|
input=["hi"],
|
||||||
|
dimensions=1,
|
||||||
|
auto_truncate=True,
|
||||||
|
task_type="RETRIEVAL_QUERY",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"response.usage: {response.usage}")
|
||||||
|
assert response.usage is not None
|
||||||
|
assert response.usage.prompt_tokens_details is not None
|
||||||
|
|
||||||
|
assert response._hidden_params["response_cost"] > 0
|
||||||
|
print(f"response:", response)
|
||||||
|
except litellm.RateLimitError as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_vertexai_embedding_embedding_latest():
|
||||||
|
try:
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
response = embedding(
|
||||||
|
model="vertex_ai/text-embedding-004",
|
||||||
|
input=["hi"],
|
||||||
|
dimensions=1,
|
||||||
|
auto_truncate=True,
|
||||||
|
task_type="RETRIEVAL_QUERY",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(response.data[0]["embedding"]) == 1
|
||||||
|
assert response.usage.prompt_tokens > 0
|
||||||
|
print(f"response:", response)
|
||||||
|
except litellm.RateLimitError as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="need to get gecko permissions on vertex ai to run this test")
|
@pytest.mark.skip(reason="need to get gecko permissions on vertex ai to run this test")
|
||||||
@pytest.mark.flaky(retries=3, delay=1)
|
@pytest.mark.flaky(retries=3, delay=1)
|
||||||
def test_vertexai_embedding_embedding_latest_input_type():
|
def test_vertexai_embedding_embedding_latest_input_type():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue