diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md
index b0d1ed698..852599cbe 100644
--- a/docs/my-website/docs/providers/vertex.md
+++ b/docs/my-website/docs/providers/vertex.md
@@ -1684,17 +1684,21 @@ Usage
+Using GCS Images
+
 ```python
 response = await litellm.aembedding(
     model="vertex_ai/multimodalembedding@001",
-    input=[
-        {
-            "image": {
-                "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
-            },
-            "text": "this is a unicorn",
-        },
-    ],
+    input="gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"  # will be sent as a gcs image
+)
+```
+
+Using base64 encoded images
+
+```python
+response = await litellm.aembedding(
+    model="vertex_ai/multimodalembedding@001",
+    input="data:image/jpeg;base64,..."  # will be sent as a base64 encoded image
 )
 ```
 
@@ -1721,9 +1725,15 @@ litellm_settings:
 $ litellm --config /path/to/config.yaml
 ```
 
-3. Make Request use OpenAI Python SDK
+3. Make Request using OpenAI Python SDK, Langchain Python SDK
+
+
+
+
+Requests with GCS Image / Video URI
+
 ```python
 import openai
 
@@ -1732,23 +1742,13 @@ client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000") #
 # request sent to model set on litellm proxy, `litellm --model`
 response = client.embeddings.create(
     model="multimodalembedding@001",
-    input = None,
-    extra_body = {
-        "instances": [
-            {
-                "image": {
-                    "bytesBase64Encoded": "base64"
-                },
-                "text": "this is a unicorn",
-            },
-        ],
-    }
+    input = "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
 )
 
 print(response)
 ```
 
-
+Requests with base64 encoded images
 
 ```python
 import openai
 
@@ -1758,23 +1758,63 @@ client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000") #
 # request sent to model set on litellm proxy, `litellm --model`
 response = client.embeddings.create(
     model="multimodalembedding@001",
-    input = None,
-    extra_body = {
-        "instances": [
-            {
-                "image": {
-                    "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
-                },
-                "text": "this is a unicorn",
-            },
-        ],
-    }
+    input = "data:image/jpeg;base64,...",
 )
 
 print(response)
 ```
+
+
+
+Requests with GCS Image / Video URI
+
+```python
+from langchain_openai import OpenAIEmbeddings
+
+embeddings_models = "multimodalembedding@001"
+
+embeddings = OpenAIEmbeddings(
+    model="multimodalembedding@001",
+    base_url="http://0.0.0.0:4000",
+    api_key="sk-1234",  # type: ignore
+)
+
+
+query_result = embeddings.embed_query(
+    "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+)
+print(query_result)
+
+```
+
+Requests with base64 encoded images
+
+```python
+from langchain_openai import OpenAIEmbeddings
+
+embeddings_models = "multimodalembedding@001"
+
+embeddings = OpenAIEmbeddings(
+    model="multimodalembedding@001",
+    base_url="http://0.0.0.0:4000",
+    api_key="sk-1234",  # type: ignore
+)
+
+
+query_result = embeddings.embed_query(
+    "data:image/jpeg;base64,..."
+)
+print(query_result)
+
+```
+
+
+
+
+
+
 
 1. Add model to config.yaml
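Note: the docs above cover GCS image URIs and base64 encoded images; the handler change below also routes `gs://` URIs containing `mp4` to the video field. A minimal sketch of what that call could look like (the bucket path is illustrative and not part of this PR, and the call needs valid Vertex AI credentials configured):

```python
import asyncio

import litellm


async def main():
    # A gs:// URI containing "mp4" is routed to `video.gcsUri` by the new
    # _process_input_element helper; other gs:// URIs go to `image.gcsUri`,
    # base64 strings to `image.bytesBase64Encoded`, and anything else is text.
    response = await litellm.aembedding(
        model="vertex_ai/multimodalembedding@001",
        input="gs://my-bucket/clips/sample.mp4",  # hypothetical video URI
    )
    print(response.data)


if __name__ == "__main__":
    asyncio.run(main())
```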
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py
index 180939556..0eda7d875 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 import httpx
 
 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexAIError,
@@ -11,9 +12,14 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu
 )
 from litellm.types.llms.vertex_ai import (
     Instance,
+    InstanceImage,
     InstanceVideo,
+    MultimodalPrediction,
+    MultimodalPredictions,
     VertexMultimodalEmbeddingRequest,
 )
+from litellm.types.utils import Embedding
+from litellm.utils import is_base64_encoded
 
 
 class VertexMultimodalEmbedding(VertexLLM):
@@ -32,9 +38,9 @@ class VertexMultimodalEmbedding(VertexLLM):
         model_response: litellm.EmbeddingResponse,
         custom_llm_provider: Literal["gemini", "vertex_ai"],
         optional_params: dict,
+        logging_obj: LiteLLMLoggingObj,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
-        logging_obj=None,
         encoding=None,
         vertex_project=None,
         vertex_location=None,
@@ -94,7 +100,7 @@ class VertexMultimodalEmbedding(VertexLLM):
             vertex_request_instance = Instance(**optional_params)
 
             if isinstance(input, str):
-                vertex_request_instance["text"] = input
+                vertex_request_instance = self._process_input_element(input)
 
             request_data["instances"] = [vertex_request_instance]
 
@@ -142,8 +148,10 @@ class VertexMultimodalEmbedding(VertexLLM):
                 model=model,
             )
         _predictions = _json_response["predictions"]
-
-        model_response.data = _predictions
+        vertex_predictions = MultimodalPredictions(predictions=_predictions)
+        model_response.data = self.transform_embedding_response_to_openai(
+            predictions=vertex_predictions
+        )
         model_response.model = model
 
         return model_response
@@ -186,11 +194,36 @@ class VertexMultimodalEmbedding(VertexLLM):
             )
 
         _predictions = _json_response["predictions"]
-
-        model_response.data = _predictions
+        vertex_predictions = MultimodalPredictions(predictions=_predictions)
+        model_response.data = self.transform_embedding_response_to_openai(
+            predictions=vertex_predictions
+        )
         model_response.model = model
 
         return model_response
 
+    def _process_input_element(self, input_element: str) -> Instance:
+        """
+        Process an input element for a multimodal embedding request. Checks whether
+        the input is a GCS URI, a base64 encoded image, or plain text.
+
+        Args:
+            input_element (str): The input element to process.
+
+        Returns:
+            Instance: An Instance dict representing the processed input element.
+        """
+        if len(input_element) == 0:
+            return Instance(text=input_element)
+        elif "gs://" in input_element:
+            if "mp4" in input_element:
+                return Instance(video=InstanceVideo(gcsUri=input_element))
+            else:
+                return Instance(image=InstanceImage(gcsUri=input_element))
+        elif is_base64_encoded(s=input_element):
+            return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
+        else:
+            return Instance(text=input_element)
+
     def process_openai_embedding_input(
         self, _input: Union[list, str]
     ) -> List[Instance]:
@@ -211,14 +244,45 @@ class VertexMultimodalEmbedding(VertexLLM):
             _input_list = _input
 
         processed_instances = []
-        for element in _input:
-            if not isinstance(element, dict):
-                # assuming that input is a list of strings
-                # example: input = ["hello from litellm"]
-                instance = Instance(text=element)
-            else:
-                # assume this is a
+        for element in _input_list:
+            if isinstance(element, str):
+                instance = Instance(**self._process_input_element(element))
+            elif isinstance(element, dict):
                 instance = Instance(**element)
+            else:
+                raise ValueError(f"Unsupported input type: {type(element)}")
             processed_instances.append(instance)
 
         return processed_instances
+
+    def transform_embedding_response_to_openai(
+        self, predictions: MultimodalPredictions
+    ) -> List[Embedding]:
+
+        openai_embeddings: List[Embedding] = []
+        if "predictions" in predictions:
+            for idx, _prediction in enumerate(predictions["predictions"]):
+                if _prediction:
+                    if "textEmbedding" in _prediction:
+                        openai_embedding_object = Embedding(
+                            embedding=_prediction["textEmbedding"],
+                            index=idx,
+                            object="embedding",
+                        )
+                        openai_embeddings.append(openai_embedding_object)
+                    elif "imageEmbedding" in _prediction:
+                        openai_embedding_object = Embedding(
+                            embedding=_prediction["imageEmbedding"],
+                            index=idx,
+                            object="embedding",
+                        )
+                        openai_embeddings.append(openai_embedding_object)
+                    elif "videoEmbeddings" in _prediction:
+                        for video_embedding in _prediction["videoEmbeddings"]:
+                            openai_embedding_object = Embedding(
+                                embedding=video_embedding["embedding"],
+                                index=idx,
+                                object="embedding",
+                            )
+                            openai_embeddings.append(openai_embedding_object)
+        return openai_embeddings
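For reference, the routing performed by the new `_process_input_element` helper can be summarized with a small standalone sketch. It uses plain dicts instead of the `Instance` TypedDict and a simplified base64 check in place of `litellm.utils.is_base64_encoded`, so it is an approximation of the logic above rather than the handler's exact code:

```python
import base64
import binascii


def route_embedding_input(element: str) -> dict:
    """Approximate the decision tree used for multimodal embedding inputs."""
    if len(element) == 0:
        return {"text": element}
    if "gs://" in element:
        # Cloud Storage URIs: mp4 goes to the video field, everything else to image.
        if "mp4" in element:
            return {"video": {"gcsUri": element}}
        return {"image": {"gcsUri": element}}
    try:
        # Simplified stand-in for litellm.utils.is_base64_encoded
        payload = element.split("base64,")[-1]
        base64.b64decode(payload, validate=True)
        return {"image": {"bytesBase64Encoded": element}}
    except (binascii.Error, ValueError):
        return {"text": element}


print(route_embedding_input("gs://bucket/cat.png"))    # {'image': {'gcsUri': ...}}
print(route_embedding_input("gs://bucket/clip.mp4"))   # {'video': {'gcsUri': ...}}
print(route_embedding_input("hello from litellm"))     # {'text': ...}
```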
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index c8ca606a9..a2a2d4ce1 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,4 +1,14 @@
 model_list:
+  - model_name: multimodalembedding@001
+    litellm_params:
+      model: vertex_ai/multimodalembedding@001
+      vertex_project: "adroit-crow-413218"
+      vertex_location: "us-central1"
+      vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json"
+  - model_name: text-embedding-ada-002
+    litellm_params:
+      model: openai/text-embedding-ada-002 # The `openai/` prefix will call openai.chat.completions.create
+      api_key: os.environ/OPENAI_API_KEY
   - model_name: db-openai-endpoint
     litellm_params:
       model: openai/gpt-3.5-turbo
@@ -23,11 +33,10 @@ general_settings:
   service_account_settings:
     enforced_params: ["user"]
 
+
 litellm_settings:
-  cache: true
-  # callbacks: ["otel"]
+  drop_params: True
+  callbacks: ["otel"]
+  success_callback: ["langfuse"]
+  failure_callback: ["langfuse"]
-
-general_settings:
-  service_account_settings:
-    enforced_params: ["user"]
diff --git a/litellm/proxy/tests/test_langchain_embedding.py b/litellm/proxy/tests/test_langchain_embedding.py
new file mode 100644
index 000000000..69ef54148
--- /dev/null
+++ b/litellm/proxy/tests/test_langchain_embedding.py
@@ -0,0 +1,17 @@
+from langchain_openai import OpenAIEmbeddings
+
+embeddings_models = "multimodalembedding@001"
+
+embeddings = OpenAIEmbeddings(
+    model="multimodalembedding@001",
+    base_url="http://0.0.0.0:4000",
+    api_key="sk-1234",  # type: ignore
+)
+
+
+query_result = embeddings.embed_query(
+    "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+)
+# print(len(query_result))
+# print(query_result[:5])
+print(query_result)
"multimodalembedding@001" + +embeddings = OpenAIEmbeddings( + model="multimodalembedding@001", + base_url="http://0.0.0.0:4000", + api_key="sk-1234", # type: ignore +) + + +query_result = embeddings.embed_query( + "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png" +) +# print(len(query_result)) +# print(query_result[:5]) +print(query_result) diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 5dfdd7c4b..f50f0eb25 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -1931,8 +1931,6 @@ async def test_vertexai_multimodal_embedding(): assert response.model == "multimodalembedding@001" assert len(response.data) == 1 response_data = response.data[0] - assert "imageEmbedding" in response_data - assert "textEmbedding" in response_data # Optional: Print for debugging print("Arguments passed to Vertex AI:", args_to_vertexai) @@ -1987,7 +1985,121 @@ async def test_vertexai_multimodal_embedding_text_input(): assert response.model == "multimodalembedding@001" assert len(response.data) == 1 response_data = response.data[0] - assert "textEmbedding" in response_data + assert response_data["embedding"] == [0.4, 0.5, 0.6] + + # Optional: Print for debugging + print("Arguments passed to Vertex AI:", args_to_vertexai) + print("Response:", response) + + +@pytest.mark.asyncio +async def test_vertexai_multimodal_embedding_image_in_input(): + load_vertex_ai_credentials() + mock_response = AsyncMock() + + def return_val(): + return { + "predictions": [ + { + "imageEmbedding": [0.1, 0.2, 0.3], # Simplified example + } + ] + } + + mock_response.json = return_val + mock_response.status_code = 200 + + expected_payload = { + "instances": [ + { + "image": { + "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png" + }, + } + ] + } + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=mock_response, + ) as mock_post: + # Act: Call the litellm.aembedding function + response = await litellm.aembedding( + model="vertex_ai/multimodalembedding@001", + input=["gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"], + ) + + # Assert + mock_post.assert_called_once() + _, kwargs = mock_post.call_args + args_to_vertexai = kwargs["json"] + + print("args to vertex ai call:", args_to_vertexai) + + assert args_to_vertexai == expected_payload + assert response.model == "multimodalembedding@001" + assert len(response.data) == 1 + response_data = response.data[0] + + assert response_data["embedding"] == [0.1, 0.2, 0.3] + + # Optional: Print for debugging + print("Arguments passed to Vertex AI:", args_to_vertexai) + print("Response:", response) + + +@pytest.mark.asyncio +async def test_vertexai_multimodal_embedding_base64image_in_input(): + load_vertex_ai_credentials() + mock_response = AsyncMock() + + image_path = "../proxy/cached_logo.jpg" + # Getting the base64 string + base64_image = encode_image(image_path) + + def return_val(): + return { + "predictions": [ + { + "imageEmbedding": [0.1, 0.2, 0.3], # Simplified example + } + ] + } + + mock_response.json = return_val + mock_response.status_code = 200 + + expected_payload = { + "instances": [ + { + "image": {"bytesBase64Encoded": base64_image}, + } + ] + } + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=mock_response, + ) as mock_post: + # Act: Call the litellm.aembedding function + response = await litellm.aembedding( + 
model="vertex_ai/multimodalembedding@001", + input=[base64_image], + ) + + # Assert + mock_post.assert_called_once() + _, kwargs = mock_post.call_args + args_to_vertexai = kwargs["json"] + + print("args to vertex ai call:", args_to_vertexai) + + assert args_to_vertexai == expected_payload + assert response.model == "multimodalembedding@001" + assert len(response.data) == 1 + response_data = response.data[0] + + assert response_data["embedding"] == [0.1, 0.2, 0.3] # Optional: Print for debugging print("Arguments passed to Vertex AI:", args_to_vertexai) diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 145aaa359..465752f5b 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -339,9 +339,15 @@ class InstanceVideo(TypedDict, total=False): videoSegmentConfig: Tuple[float, float, float] +class InstanceImage(TypedDict, total=False): + gcsUri: Optional[str] + bytesBase64Encoded: Optional[str] + mimeType: Optional[str] + + class Instance(TypedDict, total=False): text: str - image: Dict[str, str] + image: InstanceImage video: InstanceVideo @@ -349,6 +355,22 @@ class VertexMultimodalEmbeddingRequest(TypedDict, total=False): instances: List[Instance] +class VideoEmbedding(TypedDict): + startOffsetSec: int + endOffsetSec: int + embedding: List[float] + + +class MultimodalPrediction(TypedDict, total=False): + textEmbedding: List[float] + imageEmbedding: List[float] + videoEmbeddings: List[VideoEmbedding] + + +class MultimodalPredictions(TypedDict, total=False): + predictions: List[MultimodalPrediction] + + class VertexAICachedContentResponseObject(TypedDict): name: str model: str