mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
* remove data:image/jpeg;base64, prefix from base64 image input vertex_ai's multimodal embeddings endpoint expects a raw base64 string without `data:image/jpeg;base64,` prefix. * Add Vertex Multimodal Embedding Test * fix(test_vertex.py): add e2e tests on multimodal embeddings * test: unit testing * test: remove sklearn dep * test: update test with fixed route * test: fix test --------- Co-authored-by: Jonarod <jonrodd@gmail.com> Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
302 lines
11 KiB
Python
302 lines
11 KiB
Python
import json
|
|
from typing import List, Literal, Optional, Union
|
|
|
|
import httpx
|
|
|
|
import litellm
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
|
from litellm.llms.custom_httpx.http_handler import (
|
|
AsyncHTTPHandler,
|
|
HTTPHandler,
|
|
get_async_httpx_client,
|
|
)
|
|
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
|
VertexAIError,
|
|
VertexLLM,
|
|
)
|
|
from litellm.types.llms.vertex_ai import (
|
|
Instance,
|
|
InstanceImage,
|
|
InstanceVideo,
|
|
MultimodalPredictions,
|
|
VertexMultimodalEmbeddingRequest,
|
|
)
|
|
from litellm.types.utils import Embedding, EmbeddingResponse
|
|
from litellm.utils import is_base64_encoded
|
|
|
|
|
|
class VertexMultimodalEmbedding(VertexLLM):
    """
    Embedding handler for Vertex AI multimodal embedding models.

    Converts OpenAI-style embedding input (a string, or a list of strings and/or
    pre-built instance dicts) into a Vertex AI multimodal embedding request,
    sends it synchronously or asynchronously, and converts the returned
    ``predictions`` into OpenAI-style ``Embedding`` objects.
    """

    def __init__(self) -> None:
        super().__init__()
        # Model names listed as supported by this handler (not checked inside
        # this class; presumably consulted by the caller).
        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
            "multimodalembedding",
            "multimodalembedding@001",
        ]

    def multimodal_embedding(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding=False,
        timeout=300,
        client=None,
    ) -> EmbeddingResponse:
        """
        Create multimodal embeddings through Vertex AI.

        Args:
            model: Vertex model name (e.g. ``multimodalembedding@001``).
            input: A single string (plain text, ``gs://`` URI, or base64 image),
                or a list whose items are such strings or instance dicts.
            print_verbose: Unused here; kept for interface compatibility.
            model_response: Response object that is filled in and returned.
            custom_llm_provider: Provider used when resolving auth and the URL.
            optional_params: Extra parameters. If it contains ``instances``,
                that value is sent verbatim as the request instances.
            logging_obj: Logging object; ``pre_call`` is invoked before sending.
            api_key: Gemini API key forwarded to token/URL resolution.
            api_base: Optional override for the endpoint base URL.
            encoding: Unused here; kept for interface compatibility.
            vertex_project: Vertex project id (may be resolved from credentials).
            vertex_location: Vertex region.
            vertex_credentials: Credentials used to obtain an access token.
            aembedding: When True, return the coroutine produced by
                ``async_multimodal_embedding`` instead of calling synchronously.
            timeout: Seconds (int/float) or an ``httpx.Timeout``. For the sync
                path, ``None`` means 600s overall with a 5s connect timeout.
            client: Optional pre-built HTTP handler. Used as the sync handler,
                or forwarded to the async path when ``aembedding`` is True.

        Returns:
            ``model_response`` with ``data`` and ``model`` populated, or a
            coroutine resolving to it when ``aembedding`` is True.

        Raises:
            VertexAIError: On an HTTP error status, a non-200 response, or an
                ``httpx`` timeout exception (sync path).
            litellm.InternalServerError: If the response has no ``predictions``.
        """
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )

        auth_header, url = self._get_token_and_url(
            model=model,
            auth_header=_auth_header,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="embedding",
        )

        optional_params = optional_params or {}
        request_data = self._build_request_data(
            input=input, optional_params=optional_params
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {auth_header}",
        }

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={
                "complete_input_dict": request_data,
                "api_base": url,
                "headers": headers,
            },
        )

        if aembedding is True:
            return self.async_multimodal_embedding(  # type: ignore
                model=model,
                api_base=url,
                data=request_data,
                timeout=timeout,
                headers=headers,
                client=client,
                model_response=model_response,
            )

        # Only build a sync client when the sync path actually runs; previously
        # one was created (and discarded) even for async requests.
        if client is None:
            sync_handler: HTTPHandler = HTTPHandler(
                timeout=self._get_sync_httpx_timeout(timeout)
            )
        else:
            sync_handler = client  # type: ignore

        # Map transport errors to VertexAIError, matching the async path.
        try:
            response = sync_handler.post(
                url=url,
                headers=headers,
                data=json.dumps(request_data),
            )
        except httpx.HTTPStatusError as err:
            raise VertexAIError(
                status_code=err.response.status_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise VertexAIError(
                status_code=408, message="Timeout error occurred."
            ) from err

        if response.status_code != 200:
            raise VertexAIError(
                status_code=response.status_code, message=response.text
            )

        return self._populate_embedding_response(
            model=model,
            json_response=response.json(),
            model_response=model_response,
        )

    @staticmethod
    def _get_sync_httpx_timeout(timeout) -> httpx.Timeout:
        """
        Normalize ``timeout`` into an ``httpx.Timeout`` for the sync client.

        Numbers become ``httpx.Timeout(timeout)``; an ``httpx.Timeout`` is
        passed through (it used to be silently dropped); anything else,
        including ``None``, falls back to 600s overall with a 5s connect timeout.
        """
        if isinstance(timeout, httpx.Timeout):
            return timeout
        if isinstance(timeout, (int, float)):
            return httpx.Timeout(timeout)
        return httpx.Timeout(timeout=600.0, connect=5.0)

    def _build_request_data(
        self, input: Union[list, str], optional_params: dict
    ) -> VertexMultimodalEmbeddingRequest:
        """
        Build the Vertex request body (its ``instances`` list) from the input.

        Precedence: an explicit ``optional_params["instances"]`` wins; otherwise
        a list input is converted element by element; a string input becomes a
        single instance; any other input type falls back to an instance built
        from ``optional_params``.
        """
        request_data = VertexMultimodalEmbeddingRequest()

        if "instances" in optional_params:
            # Caller supplied raw Vertex instances; forward them unchanged.
            request_data["instances"] = optional_params["instances"]
        elif isinstance(input, list):
            request_data["instances"] = self.process_openai_embedding_input(
                _input=input
            )
        elif isinstance(input, str):
            request_data["instances"] = [self._process_input_element(input)]
        else:
            request_data["instances"] = [Instance(**optional_params)]

        return request_data

    def _populate_embedding_response(
        self,
        model: str,
        json_response: dict,
        model_response: EmbeddingResponse,
    ) -> EmbeddingResponse:
        """
        Fill ``model_response`` from a decoded Vertex JSON response.

        Raises:
            litellm.InternalServerError: If ``predictions`` is missing.
        """
        if "predictions" not in json_response:
            raise litellm.InternalServerError(
                message=f"embedding response does not contain 'predictions', got {json_response}",
                llm_provider="vertex_ai",
                model=model,
            )

        vertex_predictions = MultimodalPredictions(
            predictions=json_response["predictions"]
        )
        model_response.data = self.transform_embedding_response_to_openai(
            predictions=vertex_predictions
        )
        model_response.model = model
        return model_response

    async def async_multimodal_embedding(
        self,
        model: str,
        api_base: str,
        data: VertexMultimodalEmbeddingRequest,
        model_response: litellm.EmbeddingResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        headers: Optional[dict] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> litellm.EmbeddingResponse:
        """
        Send a prepared multimodal embedding request asynchronously.

        Args:
            model: Model name recorded on the response.
            api_base: Fully resolved endpoint URL.
            data: Request body built by ``multimodal_embedding``.
            model_response: Response object that is filled in and returned.
            timeout: Seconds or ``httpx.Timeout``; only used when ``client`` is
                None, to obtain a client from ``get_async_httpx_client``.
            headers: HTTP headers (defaults to none).
            client: Optional pre-built async HTTP handler.

        Raises:
            VertexAIError: On an HTTP error status or an ``httpx`` timeout.
            litellm.InternalServerError: If the response has no ``predictions``.
        """
        if headers is None:
            headers = {}

        if client is None:
            if isinstance(timeout, (int, float)):
                timeout = httpx.Timeout(timeout)
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params={"timeout": timeout},
            )

        try:
            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            raise VertexAIError(
                status_code=err.response.status_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise VertexAIError(
                status_code=408, message="Timeout error occurred."
            ) from err

        return self._populate_embedding_response(
            model=model,
            json_response=response.json(),
            model_response=model_response,
        )

    def _process_input_element(self, input_element: str) -> Instance:
        """
        Classify one string input as a GCS URI, a base64 image, or plain text.

        Args:
            input_element: The string to convert.

        Returns:
            Instance: A ``video`` or ``image`` instance for ``gs://`` URIs, an
            ``image`` instance with raw base64 bytes when ``is_base64_encoded``
            accepts the string, otherwise a ``text`` instance (including for
            the empty string).
        """
        if not input_element:
            return Instance(text=input_element)

        # Match the URI scheme only; a plain sentence that merely mentions
        # "gs://" must stay text rather than being sent as an image URI.
        if input_element.startswith("gs://"):
            # NOTE(review): substring match - any URI containing "mp4" anywhere
            # is treated as video, and other video formats are sent as images.
            if "mp4" in input_element:
                return Instance(video=InstanceVideo(gcsUri=input_element))
            return Instance(image=InstanceImage(gcsUri=input_element))

        if is_base64_encoded(s=input_element):
            # The endpoint expects raw base64, so drop any
            # "data:<mime>;base64," prefix up to the comma.
            raw_b64 = (
                input_element.split(",")[1] if "," in input_element else input_element
            )
            return Instance(image=InstanceImage(bytesBase64Encoded=raw_b64))

        return Instance(text=input_element)

    def process_openai_embedding_input(
        self, _input: Union[list, str]
    ) -> List[Instance]:
        """
        Convert OpenAI-style embedding input into Vertex instances.

        Args:
            _input: A single string, or a list of strings and/or dicts. Strings
                go through ``_process_input_element``; dicts are used as
                instance fields directly.

        Returns:
            List[Instance]: One instance per input element.

        Raises:
            ValueError: If a list element is neither ``str`` nor ``dict``.
        """
        elements = _input if isinstance(_input, list) else [_input]

        processed_instances: List[Instance] = []
        for element in elements:
            if isinstance(element, str):
                processed_instances.append(self._process_input_element(element))
            elif isinstance(element, dict):
                processed_instances.append(Instance(**element))
            else:
                raise ValueError(f"Unsupported input type: {type(element)}")

        return processed_instances

    def transform_embedding_response_to_openai(
        self, predictions: MultimodalPredictions
    ) -> List[Embedding]:
        """
        Convert Vertex multimodal predictions into OpenAI ``Embedding`` objects.

        Each non-empty prediction contributes its ``textEmbedding`` if present,
        else its ``imageEmbedding``, else one entry per ``videoEmbeddings``
        segment. ``index`` is the prediction's position, so all video segments
        from one prediction share the same index, and empty predictions are
        skipped without producing an entry.
        """
        openai_embeddings: List[Embedding] = []
        if "predictions" not in predictions:
            return openai_embeddings

        for idx, prediction in enumerate(predictions["predictions"]):
            if not prediction:
                continue

            if "textEmbedding" in prediction:
                vectors = [prediction["textEmbedding"]]
            elif "imageEmbedding" in prediction:
                vectors = [prediction["imageEmbedding"]]
            elif "videoEmbeddings" in prediction:
                vectors = [
                    video_embedding["embedding"]
                    for video_embedding in prediction["videoEmbeddings"]
                ]
            else:
                vectors = []

            for vector in vectors:
                openai_embeddings.append(
                    Embedding(embedding=vector, index=idx, object="embedding")
                )

        return openai_embeddings
|