litellm-mirror/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
from typing import List, Optional, Union, cast

from httpx import Headers, Response

from litellm.exceptions import InternalServerError
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import LiteLLMLoggingObj
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.llms.vertex_ai import (
    Instance,
    InstanceImage,
    InstanceVideo,
    MultimodalPredictions,
    VertexMultimodalEmbeddingRequest,
)
from litellm.types.utils import (
    Embedding,
    EmbeddingResponse,
    PromptTokensDetailsWrapper,
    Usage,
)
from litellm.utils import _count_characters, is_base64_encoded

from ...base_llm.embedding.transformation import BaseEmbeddingConfig
from ..common_utils import VertexAIError


class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
def get_supported_openai_params(self, model: str) -> list:
return ["dimensions"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "dimensions":
optional_params["outputDimensionality"] = value
return optional_params
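
    # Illustrative sketch (not executed): the OpenAI "dimensions" param is
    # mapped onto Vertex AI's "outputDimensionality" field; values below are
    # hypothetical.
    #
    #   VertexAIMultimodalEmbeddingConfig().map_openai_params(
    #       non_default_params={"dimensions": 256},
    #       optional_params={},
    #       model="multimodalembedding@001",
    #       drop_params=False,
    #   )
    #   # -> {"outputDimensionality": 256}
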
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
default_headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {api_key}",
}
headers.update(default_headers)
        return headers

def _process_input_element(self, input_element: str) -> Instance:
"""
Process the input element for multimodal embedding requests. checks if the if the input is gcs uri, base64 encoded image or plain text.
Args:
input_element (str): The input element to process.
Returns:
Dict[str, Any]: A dictionary representing the processed input element.
"""
if len(input_element) == 0:
return Instance(text=input_element)
elif "gs://" in input_element:
if "mp4" in input_element:
return Instance(video=InstanceVideo(gcsUri=input_element))
else:
return Instance(image=InstanceImage(gcsUri=input_element))
elif is_base64_encoded(s=input_element):
return Instance(
image=InstanceImage(
bytesBase64Encoded=(
input_element.split(",")[1]
if "," in input_element
else input_element
)
)
)
else:
return Instance(text=input_element)
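
    # Illustrative sketch (not executed) of how a single element is routed,
    # assuming Instance/InstanceImage/InstanceVideo are TypedDicts; URIs and
    # payloads are hypothetical.
    #
    #   self._process_input_element("gs://bucket/clip.mp4")
    #   # -> {"video": {"gcsUri": "gs://bucket/clip.mp4"}}
    #   self._process_input_element("gs://bucket/photo.png")
    #   # -> {"image": {"gcsUri": "gs://bucket/photo.png"}}
    #   self._process_input_element("data:image/png;base64,iVBOR...")
    #   # -> {"image": {"bytesBase64Encoded": "iVBOR..."}}
    #   self._process_input_element("a sunny beach")
    #   # -> {"text": "a sunny beach"}
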
def process_openai_embedding_input(
self, _input: Union[list, str]
) -> List[Instance]:
"""
Process the input for multimodal embedding requests.
Args:
_input (Union[list, str]): The input data to process.
Returns:
Union[Instance, List[Instance]]: Either a single Instance or list of Instance objects.
"""
_input_list = [_input] if not isinstance(_input, list) else _input
processed_instances = []
i = 0
while i < len(_input_list):
current = _input_list[i]
# Look ahead for potential media elements
next_elem = _input_list[i + 1] if i + 1 < len(_input_list) else None
            # If current is a string, build an instance from it, and pair it
            # with a following GCS URI (image or video) when present
if isinstance(current, str):
instance_args: Instance = {}
# Process current element
if "gs://" not in current:
instance_args["text"] = current
elif "mp4" in current:
instance_args["video"] = InstanceVideo(gcsUri=current)
else:
instance_args["image"] = InstanceImage(gcsUri=current)
# Check next element if it's a GCS URI
if next_elem and isinstance(next_elem, str) and "gs://" in next_elem:
if "mp4" in next_elem:
instance_args["video"] = InstanceVideo(gcsUri=next_elem)
else:
instance_args["image"] = InstanceImage(gcsUri=next_elem)
i += 2 # Skip next element since we processed it
else:
i += 1 # Move to next element
processed_instances.append(instance_args)
continue
# Handle dict or other types
if isinstance(current, dict):
instance = Instance(**current)
processed_instances.append(instance)
else:
raise ValueError(f"Unsupported input type: {type(current)}")
i += 1
return processed_instances
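
    # Illustrative sketch (not executed): a text element immediately followed
    # by a GCS URI is folded into one multimodal instance; paths are
    # hypothetical.
    #
    #   self.process_openai_embedding_input(
    #       _input=["a sunny beach", "gs://bucket/photo.png"]
    #   )
    #   # -> [{"text": "a sunny beach",
    #   #      "image": {"gcsUri": "gs://bucket/photo.png"}}]
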
def transform_embedding_request(
self,
model: str,
input: AllEmbeddingInputValues,
optional_params: dict,
headers: dict,
) -> dict:
optional_params = optional_params or {}
request_data = VertexMultimodalEmbeddingRequest(instances=[])
if "instances" in optional_params:
request_data["instances"] = optional_params["instances"]
elif isinstance(input, list):
vertex_instances: List[Instance] = self.process_openai_embedding_input(
_input=input
)
request_data["instances"] = vertex_instances
else:
# construct instances
vertex_request_instance = Instance(**optional_params)
if isinstance(input, str):
vertex_request_instance = self._process_input_element(input)
request_data["instances"] = [vertex_request_instance]
return cast(dict, request_data)
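
    # Illustrative sketch (not executed) of the request body this produces
    # for a single GCS image input; the URI is hypothetical.
    #
    #   self.transform_embedding_request(
    #       model="multimodalembedding@001",
    #       input="gs://bucket/photo.png",
    #       optional_params={},
    #       headers={},
    #   )
    #   # -> {"instances": [{"image": {"gcsUri": "gs://bucket/photo.png"}}]}
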
def transform_embedding_response(
self,
model: str,
raw_response: Response,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str],
request_data: dict,
optional_params: dict,
litellm_params: dict,
) -> EmbeddingResponse:
if raw_response.status_code != 200:
raise Exception(f"Error: {raw_response.status_code} {raw_response.text}")
_json_response = raw_response.json()
if "predictions" not in _json_response:
raise InternalServerError(
message=f"embedding response does not contain 'predictions', got {_json_response}",
llm_provider="vertex_ai",
model=model,
)
_predictions = _json_response["predictions"]
vertex_predictions = MultimodalPredictions(predictions=_predictions)
model_response.data = self.transform_embedding_response_to_openai(
predictions=vertex_predictions
)
model_response.model = model
model_response.usage = self.calculate_usage(
request_data=cast(VertexMultimodalEmbeddingRequest, request_data),
vertex_predictions=vertex_predictions,
)
return model_response
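
    # Illustrative sketch of the raw Vertex response shape this method
    # expects (embedding values truncated and hypothetical):
    #
    #   {
    #       "predictions": [
    #           {"textEmbedding": [0.01, -0.02, ...]},
    #           {"imageEmbedding": [0.03, 0.04, ...]},
    #       ]
    #   }
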
def calculate_usage(
self,
request_data: VertexMultimodalEmbeddingRequest,
vertex_predictions: MultimodalPredictions,
) -> Usage:
## Calculate text embeddings usage
prompt: Optional[str] = None
character_count: Optional[int] = None
for instance in request_data["instances"]:
text = instance.get("text")
if text:
if prompt is None:
prompt = text
else:
prompt += text
if prompt is not None:
character_count = _count_characters(prompt)
## Calculate image embeddings usage
image_count = 0
for instance in request_data["instances"]:
if instance.get("image"):
image_count += 1
## Calculate video embeddings usage
video_length_seconds = 0
for prediction in vertex_predictions["predictions"]:
video_embeddings = prediction.get("videoEmbeddings")
if video_embeddings:
for embedding in video_embeddings:
duration = embedding["endOffsetSec"] - embedding["startOffsetSec"]
video_length_seconds += duration
prompt_tokens_details = PromptTokensDetailsWrapper(
character_count=character_count,
image_count=image_count,
video_length_seconds=video_length_seconds,
)
return Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
prompt_tokens_details=prompt_tokens_details,
)
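
    # Illustrative sketch (not executed); the instance and prediction
    # literals are hypothetical TypedDict-shaped dicts.
    #
    #   self.calculate_usage(
    #       request_data={"instances": [
    #           {"text": "a sunny beach"},
    #           {"image": {"gcsUri": "gs://bucket/photo.png"}},
    #       ]},
    #       vertex_predictions={"predictions": [
    #           {"videoEmbeddings": [
    #               {"startOffsetSec": 0, "endOffsetSec": 16, "embedding": [...]},
    #           ]}
    #       ]},
    #   )
    #   # -> Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0) with
    #   #    image_count=1, video_length_seconds=16, and character_count
    #   #    computed by _count_characters over the concatenated text.
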
def transform_embedding_response_to_openai(
self, predictions: MultimodalPredictions
) -> List[Embedding]:
openai_embeddings: List[Embedding] = []
if "predictions" in predictions:
for idx, _prediction in enumerate(predictions["predictions"]):
if _prediction:
if "textEmbedding" in _prediction:
openai_embedding_object = Embedding(
embedding=_prediction["textEmbedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
elif "imageEmbedding" in _prediction:
openai_embedding_object = Embedding(
embedding=_prediction["imageEmbedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
elif "videoEmbeddings" in _prediction:
for video_embedding in _prediction["videoEmbeddings"]:
openai_embedding_object = Embedding(
embedding=video_embedding["embedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
return openai_embeddings
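
    # Illustrative sketch (not executed): each prediction maps to one
    # Embedding, except video predictions, which fan out into one Embedding
    # per segment while reusing the prediction's index.
    #
    #   self.transform_embedding_response_to_openai(
    #       predictions={"predictions": [
    #           {"textEmbedding": [0.1, 0.2]},
    #           {"videoEmbeddings": [
    #               {"embedding": [0.3, 0.4], "startOffsetSec": 0, "endOffsetSec": 8},
    #               {"embedding": [0.5, 0.6], "startOffsetSec": 8, "endOffsetSec": 16},
    #           ]},
    #       ]}
    #   )
    #   # -> [Embedding(embedding=[0.1, 0.2], index=0, object="embedding"),
    #   #     Embedding(embedding=[0.3, 0.4], index=1, object="embedding"),
    #   #     Embedding(embedding=[0.5, 0.6], index=1, object="embedding")]
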
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return VertexAIError(
status_code=status_code, message=error_message, headers=headers
)