diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 8fc67c0c2..a88925330 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -9,7 +9,7 @@ import types
 import uuid
 from enum import Enum
 from functools import partial
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
@@ -38,12 +38,15 @@ from litellm.types.llms.vertex_ai import (
     FunctionDeclaration,
     GenerateContentResponseBody,
     GenerationConfig,
+    Instance,
+    InstanceVideo,
     PartType,
     RequestBody,
     SafetSettingsConfig,
     SystemInstructions,
     ToolConfig,
     Tools,
+    VertexMultimodalEmbeddingRequest,
 )
 from litellm.types.utils import GenericStreamingChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -598,6 +601,10 @@ class VertexLLM(BaseLLM):
         self._credentials: Optional[Any] = None
         self.project_id: Optional[str] = None
         self.async_handler: Optional[AsyncHTTPHandler] = None
+        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
+            "multimodalembedding",
+            "multimodalembedding@001",
+        ]
 
     def _process_response(
         self,
@@ -1541,6 +1548,160 @@ class VertexLLM(BaseLLM):
 
         return model_response
 
+    def multimodal_embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        print_verbose,
+        model_response: litellm.EmbeddingResponse,
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        logging_obj=None,
+        encoding=None,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        aembedding=False,
+        timeout=300,
+        client=None,
+    ):
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+        auth_header, _ = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        optional_params = optional_params or {}
+
+        request_data = VertexMultimodalEmbeddingRequest()
+
+        if "instances" in optional_params:
+            request_data["instances"] = optional_params["instances"]
+        elif isinstance(input, list):
+            request_data["instances"] = input
+        else:
+            # construct instances
+            vertex_request_instance = Instance(**optional_params)
+
+            if isinstance(input, str):
+                vertex_request_instance["text"] = input
+
+            request_data["instances"] = [vertex_request_instance]
+
+        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        logging_obj.pre_call(
+            input=[],
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        logging_obj.pre_call(
+            input=[],
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {auth_header}",
+        }
+
+        if aembedding is True:
+            return self.async_multimodal_embedding(
+                model=model,
+                api_base=url,
+                data=request_data,
+                timeout=timeout,
+                headers=headers,
+                client=client,
+                model_response=model_response,
+            )
+
+        response = sync_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        model_response.data = _predictions
+        model_response.model = model
+
+        return model_response
+
+    async def async_multimodal_embedding(
+        self,
+        model: str,
+        api_base: str,
+        data: VertexMultimodalEmbeddingRequest,
+        model_response: litellm.EmbeddingResponse,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> litellm.EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        model_response.data = _predictions
+        model_response.model = model
+
+        return model_response
+
 
 class ModelResponseIterator:
     def __init__(self, streaming_response, sync_stream: bool):
diff --git a/litellm/main.py b/litellm/main.py
index f2c6df306..ee327c2f7 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3477,19 +3477,39 @@ def embedding(
             or get_secret("VERTEX_CREDENTIALS")
         )
 
-        response = vertex_ai.embedding(
-            model=model,
-            input=input,
-            encoding=encoding,
-            logging_obj=logging,
-            optional_params=optional_params,
-            model_response=EmbeddingResponse(),
-            vertex_project=vertex_ai_project,
-            vertex_location=vertex_ai_location,
-            vertex_credentials=vertex_credentials,
-            aembedding=aembedding,
-            print_verbose=print_verbose,
-        )
+        if (
+            "image" in optional_params
+            or "video" in optional_params
+            or model in vertex_chat_completion.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
+        ):
+            # multimodal embedding is supported on vertex httpx
+            response = vertex_chat_completion.multimodal_embedding(
+                model=model,
+                input=input,
+                encoding=encoding,
+                logging_obj=logging,
+                optional_params=optional_params,
+                model_response=EmbeddingResponse(),
+                vertex_project=vertex_ai_project,
+                vertex_location=vertex_ai_location,
+                vertex_credentials=vertex_credentials,
+                aembedding=aembedding,
+                print_verbose=print_verbose,
+            )
+        else:
+            response = vertex_ai.embedding(
+                model=model,
+                input=input,
+                encoding=encoding,
+                logging_obj=logging,
+                optional_params=optional_params,
+                model_response=EmbeddingResponse(),
+                vertex_project=vertex_ai_project,
+                vertex_location=vertex_ai_location,
+                vertex_credentials=vertex_credentials,
+                aembedding=aembedding,
+                print_verbose=print_verbose,
+            )
     elif custom_llm_provider == "oobabooga":
         response = oobabooga.embedding(
             model=model,
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 5e61e4f52..fe7d1a8c8 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -1836,6 +1836,36 @@ def test_vertexai_embedding():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.asyncio()
+async def test_vertexai_multimodal_embedding():
+    load_vertex_ai_credentials()
+
+    try:
+        litellm.set_verbose = True
+        response = await litellm.aembedding(
+            model="vertex_ai/multimodalembedding@001",
+            input=[
+                {
+                    "image": {
+                        "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+                    },
+                    "text": "this is a unicorn",
+                },
+            ],
+        )
+        print(f"response:", response)
+        assert response.model == "multimodalembedding@001"
+
+        _response_data = response.data[0]
+
+        assert "imageEmbedding" in _response_data
+        assert "textEmbedding" in _response_data
+    except litellm.RateLimitError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.skip(
     reason="new test - works locally running into vertex version issues on ci/cd"
 )
diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py
index 6758c356f..5586d4861 100644
--- a/litellm/types/llms/vertex_ai.py
+++ b/litellm/types/llms/vertex_ai.py
@@ -1,6 +1,6 @@
 import json
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
 
 from typing_extensions import (
     Protocol,
@@ -305,3 +305,18 @@ class ResponseTuningJob(TypedDict):
     ]
     createTime: Optional[str]
     updateTime: Optional[str]
+
+
+class InstanceVideo(TypedDict, total=False):
+    gcsUri: str
+    videoSegmentConfig: Tuple[float, float, float]
+
+
+class Instance(TypedDict, total=False):
+    text: str
+    image: Dict[str, str]
+    video: InstanceVideo
+
+
+class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
+    instances: List[Instance]
diff --git a/litellm/utils.py b/litellm/utils.py
index 93717595e..0c15cae53 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -541,7 +541,7 @@ def function_setup(
             call_type == CallTypes.embedding.value
             or call_type == CallTypes.aembedding.value
         ):
-            messages = args[1] if len(args) > 1 else kwargs["input"]
+            messages = args[1] if len(args) > 1 else kwargs.get("input", None)
         elif (
            call_type == CallTypes.image_generation.value
            or call_type == CallTypes.aimage_generation.value