[Vertex Multimodal embeddings] Fixes to work with Langchain OpenAI Embedding (#5949)

* fix parallel request limiter - use one cache update call

* ci/cd run again

* run ci/cd again

* use docker username password

* fix config.yml

* fix config

* fix config

* fix config.yml

* ci/cd run again

* use correct typing for batch set cache

* fix async_set_cache_pipeline

* fix only check user id tpm / rpm limits when limits set

* fix test_openai_azure_embedding_with_oidc_and_cf

* add InstanceImage type

* fix vertex image transform

* add langchain vertex test request

* add new vertex test

* update multimodal embedding tests

* add test_vertexai_multimodal_embedding_base64image_in_input

* simplify langchain mm embedding usage

* add langchain example for multimodal embeddings on vertex

* fix linting error
Ishaan Jaff 2024-09-27 18:04:03 -07:00 committed by GitHub
parent bd17424c4b
commit fd87ae69b8
6 changed files with 318 additions and 54 deletions


@@ -1684,17 +1684,21 @@ Usage
<Tabs>
<TabItem value="sdk" label="SDK">
Using GCS Images
```python
response = await litellm.aembedding(
model="vertex_ai/multimodalembedding@001",
input=[
{
"image": {
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
},
"text": "this is a unicorn",
},
],
input="gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png" # will be sent as a gcs image
)
```
Using base64 encoded images
```python
response = await litellm.aembedding(
model="vertex_ai/multimodalembedding@001",
input="data:image/jpeg;base64,..." # will be sent as a base64 encoded image
)
```
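For reference, a minimal sketch of building such a base64 data URI from a local file (the file name here is just a placeholder):
```python
import base64

import litellm

# placeholder path - any local JPEG works
with open("my_image.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

response = await litellm.aembedding(
    model="vertex_ai/multimodalembedding@001",
    input=f"data:image/jpeg;base64,{encoded}",  # sent as a base64 encoded image
)
```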
@@ -1721,9 +1725,15 @@ litellm_settings:
$ litellm --config /path/to/config.yaml
```
3. Make Request use OpenAI Python SDK
3. Make a request using the OpenAI Python SDK or Langchain Python SDK
<Tabs>
<TabItem value="OpenAI SDK" label="OpenAI SDK">
Requests with GCS Image / Video URI
```python
import openai
@@ -1732,23 +1742,13 @@ client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
# # request sent to model set on litellm proxy, `litellm --model`
response = client.embeddings.create(
model="multimodalembedding@001",
input = None,
extra_body = {
"instances": [
{
"image": {
"bytesBase64Encoded": "base64"
},
"text": "this is a unicorn",
},
],
}
input = "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
)
print(response)
```
Requests with base64 encoded images
```python
import openai
@@ -1758,23 +1758,63 @@ client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
# # request sent to model set on litellm proxy, `litellm --model`
response = client.embeddings.create(
model="multimodalembedding@001",
input = None,
extra_body = {
"instances": [
{
"image": {
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
},
"text": "this is a unicorn",
},
],
}
input = "data:image/jpeg;base64,...",
)
print(response)
```
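Multiple inputs can also be batched into a single request; a sketch (illustrative, mixing a GCS URI with plain text) assuming list inputs are routed element-by-element as in the handler change below:
```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# GCS URIs are sent as image/video instances, other strings as text instances
response = client.embeddings.create(
    model="multimodalembedding@001",
    input=[
        "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
        "this is a unicorn",
    ],
)
print(response)
```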
</TabItem>
<TabItem value="langchain" label="Langchain">
Requests with GCS Image / Video URI
```python
from langchain_openai import OpenAIEmbeddings
embeddings_models = "multimodalembedding@001"
embeddings = OpenAIEmbeddings(
model="multimodalembedding@001",
base_url="http://0.0.0.0:4000",
api_key="sk-1234", # type: ignore
)
query_result = embeddings.embed_query(
"gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
)
print(query_result)
```
Requests with base64 encoded images
```python
from langchain_openai import OpenAIEmbeddings
embeddings_models = "multimodalembedding@001"
embeddings = OpenAIEmbeddings(
model="multimodalembedding@001",
base_url="http://0.0.0.0:4000",
api_key="sk-1234", # type: ignore
)
query_result = embeddings.embed_query(
"data:image/jpeg;base64,..."
)
print(query_result)
```
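For several inputs in one call, the same client exposes `embed_documents`; a sketch following the queries above (inputs are illustrative):
```python
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(
    model="multimodalembedding@001",
    base_url="http://0.0.0.0:4000",
    api_key="sk-1234",  # type: ignore
)

# one GCS image URI and one plain-text input, embedded in a single call
documents_result = embeddings.embed_documents(
    [
        "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
        "this is a unicorn",
    ]
)
print(documents_result)
```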
</TabItem>
</Tabs>
</TabItem>
<TabItem value="proxy-vtx" label="LiteLLM PROXY (Vertex SDK)">
1. Add model to config.yaml


@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexAIError,
@@ -11,9 +12,14 @@ from litellm.llms.vertex_ai_and_google_ai_stu
)
from litellm.types.llms.vertex_ai import (
Instance,
InstanceImage,
InstanceVideo,
MultimodalPrediction,
MultimodalPredictions,
VertexMultimodalEmbeddingRequest,
)
from litellm.types.utils import Embedding
from litellm.utils import is_base64_encoded
class VertexMultimodalEmbedding(VertexLLM):
@@ -32,9 +38,9 @@ class VertexMultimodalEmbedding(VertexLLM):
model_response: litellm.EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"],
optional_params: dict,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
encoding=None,
vertex_project=None,
vertex_location=None,
@@ -94,7 +100,7 @@ class VertexMultimodalEmbedding(VertexLLM):
vertex_request_instance = Instance(**optional_params)
if isinstance(input, str):
vertex_request_instance["text"] = input
vertex_request_instance = self._process_input_element(input)
request_data["instances"] = [vertex_request_instance]
@@ -142,8 +148,10 @@ class VertexMultimodalEmbedding(VertexLLM):
model=model,
)
_predictions = _json_response["predictions"]
model_response.data = _predictions
vertex_predictions = MultimodalPredictions(predictions=_predictions)
model_response.data = self.transform_embedding_response_to_openai(
predictions=vertex_predictions
)
model_response.model = model
return model_response
@@ -186,11 +194,36 @@ class VertexMultimodalEmbedding(VertexLLM):
)
_predictions = _json_response["predictions"]
model_response.data = _predictions
vertex_predictions = MultimodalPredictions(predictions=_predictions)
model_response.data = self.transform_embedding_response_to_openai(
predictions=vertex_predictions
)
model_response.model = model
return model_response
def _process_input_element(self, input_element: str) -> Instance:
"""
Process the input element for multimodal embedding requests. Checks whether the input is a GCS URI, a base64-encoded image, or plain text.
Args:
input_element (str): The input element to process.
Returns:
Instance: The processed input element as an Instance TypedDict.
"""
if len(input_element) == 0:
return Instance(text=input_element)
elif "gs://" in input_element:
if "mp4" in input_element:
return Instance(video=InstanceVideo(gcsUri=input_element))
else:
return Instance(image=InstanceImage(gcsUri=input_element))
elif is_base64_encoded(s=input_element):
return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
else:
return Instance(text=input_element)
def process_openai_embedding_input(
self, _input: Union[list, str]
) -> List[Instance]:
@@ -211,14 +244,45 @@ class VertexMultimodalEmbedding(VertexLLM):
_input_list = _input
processed_instances = []
for element in _input:
if not isinstance(element, dict):
# assuming that input is a list of strings
# example: input = ["hello from litellm"]
instance = Instance(text=element)
else:
# assume this is a
for element in _input_list:
if isinstance(element, str):
instance = Instance(**self._process_input_element(element))
elif isinstance(element, dict):
instance = Instance(**element)
else:
raise ValueError(f"Unsupported input type: {type(element)}")
processed_instances.append(instance)
return processed_instances
def transform_embedding_response_to_openai(
self, predictions: MultimodalPredictions
) -> List[Embedding]:
openai_embeddings: List[Embedding] = []
if "predictions" in predictions:
for idx, _prediction in enumerate(predictions["predictions"]):
if _prediction:
if "textEmbedding" in _prediction:
openai_embedding_object = Embedding(
embedding=_prediction["textEmbedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
elif "imageEmbedding" in _prediction:
openai_embedding_object = Embedding(
embedding=_prediction["imageEmbedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
elif "videoEmbeddings" in _prediction:
for video_embedding in _prediction["videoEmbeddings"]:
openai_embedding_object = Embedding(
embedding=video_embedding["embedding"],
index=idx,
object="embedding",
)
openai_embeddings.append(openai_embedding_object)
return openai_embeddings
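For orientation, a sketch (illustrative values only) of the mapping that `transform_embedding_response_to_openai` performs on Vertex predictions:
```python
# Vertex-style predictions, as wrapped in MultimodalPredictions above
vertex_predictions = {
    "predictions": [
        {"textEmbedding": [0.1, 0.2, 0.3]},
        {"imageEmbedding": [0.4, 0.5, 0.6]},
        {"videoEmbeddings": [{"startOffsetSec": 0, "endOffsetSec": 7, "embedding": [0.7, 0.8]}]},
    ]
}

# Expected OpenAI-style embedding objects (video predictions expand to one
# object per segment, each keeping the index of its source prediction):
# [
#     {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
#     {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]},
#     {"object": "embedding", "index": 2, "embedding": [0.7, 0.8]},
# ]
```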


@@ -1,4 +1,14 @@
model_list:
- model_name: multimodalembedding@001
litellm_params:
model: vertex_ai/multimodalembedding@001
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"
vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json"
- model_name: text-embedding-ada-002
litellm_params:
model: openai/text-embedding-ada-002 # The `openai/` prefix routes this through the OpenAI API (embeddings endpoint here)
api_key: os.environ/OPENAI_API_KEY
- model_name: db-openai-endpoint
litellm_params:
model: openai/gpt-3.5-turbo
@@ -23,11 +33,10 @@ general_settings:
service_account_settings:
enforced_params: ["user"]
litellm_settings:
cache: true
# callbacks: ["otel"]
drop_params: True
callbacks: ["otel"]
success_callback: ["langfuse"]
failure_callback: ["langfuse"]
general_settings:
service_account_settings:
enforced_params: ["user"]


@@ -0,0 +1,17 @@
from langchain_openai import OpenAIEmbeddings
embeddings_models = "multimodalembedding@001"
embeddings = OpenAIEmbeddings(
model="multimodalembedding@001",
base_url="http://0.0.0.0:4000",
api_key="sk-1234", # type: ignore
)
query_result = embeddings.embed_query(
"gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
)
# print(len(query_result))
# print(query_result[:5])
print(query_result)


@@ -1931,8 +1931,6 @@ async def test_vertexai_multimodal_embedding():
assert response.model == "multimodalembedding@001"
assert len(response.data) == 1
response_data = response.data[0]
assert "imageEmbedding" in response_data
assert "textEmbedding" in response_data
# Optional: Print for debugging
print("Arguments passed to Vertex AI:", args_to_vertexai)
@@ -1987,7 +1985,121 @@ async def test_vertexai_multimodal_embedding_text_input():
assert response.model == "multimodalembedding@001"
assert len(response.data) == 1
response_data = response.data[0]
assert "textEmbedding" in response_data
assert response_data["embedding"] == [0.4, 0.5, 0.6]
# Optional: Print for debugging
print("Arguments passed to Vertex AI:", args_to_vertexai)
print("Response:", response)
@pytest.mark.asyncio
async def test_vertexai_multimodal_embedding_image_in_input():
load_vertex_ai_credentials()
mock_response = AsyncMock()
def return_val():
return {
"predictions": [
{
"imageEmbedding": [0.1, 0.2, 0.3], # Simplified example
}
]
}
mock_response.json = return_val
mock_response.status_code = 200
expected_payload = {
"instances": [
{
"image": {
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
},
}
]
}
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=mock_response,
) as mock_post:
# Act: Call the litellm.aembedding function
response = await litellm.aembedding(
model="vertex_ai/multimodalembedding@001",
input=["gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"],
)
# Assert
mock_post.assert_called_once()
_, kwargs = mock_post.call_args
args_to_vertexai = kwargs["json"]
print("args to vertex ai call:", args_to_vertexai)
assert args_to_vertexai == expected_payload
assert response.model == "multimodalembedding@001"
assert len(response.data) == 1
response_data = response.data[0]
assert response_data["embedding"] == [0.1, 0.2, 0.3]
# Optional: Print for debugging
print("Arguments passed to Vertex AI:", args_to_vertexai)
print("Response:", response)
@pytest.mark.asyncio
async def test_vertexai_multimodal_embedding_base64image_in_input():
load_vertex_ai_credentials()
mock_response = AsyncMock()
image_path = "../proxy/cached_logo.jpg"
# Getting the base64 string
base64_image = encode_image(image_path)
def return_val():
return {
"predictions": [
{
"imageEmbedding": [0.1, 0.2, 0.3], # Simplified example
}
]
}
mock_response.json = return_val
mock_response.status_code = 200
expected_payload = {
"instances": [
{
"image": {"bytesBase64Encoded": base64_image},
}
]
}
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=mock_response,
) as mock_post:
# Act: Call the litellm.aembedding function
response = await litellm.aembedding(
model="vertex_ai/multimodalembedding@001",
input=[base64_image],
)
# Assert
mock_post.assert_called_once()
_, kwargs = mock_post.call_args
args_to_vertexai = kwargs["json"]
print("args to vertex ai call:", args_to_vertexai)
assert args_to_vertexai == expected_payload
assert response.model == "multimodalembedding@001"
assert len(response.data) == 1
response_data = response.data[0]
assert response_data["embedding"] == [0.1, 0.2, 0.3]
# Optional: Print for debugging
print("Arguments passed to Vertex AI:", args_to_vertexai)


@@ -339,9 +339,15 @@ class InstanceVideo(TypedDict, total=False):
videoSegmentConfig: Tuple[float, float, float]
class InstanceImage(TypedDict, total=False):
gcsUri: Optional[str]
bytesBase64Encoded: Optional[str]
mimeType: Optional[str]
class Instance(TypedDict, total=False):
text: str
image: Dict[str, str]
image: InstanceImage
video: InstanceVideo
@@ -349,6 +355,22 @@ class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
instances: List[Instance]
class VideoEmbedding(TypedDict):
startOffsetSec: int
endOffsetSec: int
embedding: List[float]
class MultimodalPrediction(TypedDict, total=False):
textEmbedding: List[float]
imageEmbedding: List[float]
videoEmbeddings: List[VideoEmbedding]
class MultimodalPredictions(TypedDict, total=False):
predictions: List[MultimodalPrediction]
class VertexAICachedContentResponseObject(TypedDict):
name: str
model: str
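To make the new types concrete, a small sketch (illustrative values) of a request body assembled from these TypedDicts:
```python
from litellm.types.llms.vertex_ai import (
    Instance,
    InstanceImage,
    InstanceVideo,
    VertexMultimodalEmbeddingRequest,
)

# one text+image instance and one video instance in a single request
request = VertexMultimodalEmbeddingRequest(
    instances=[
        Instance(
            text="this is a unicorn",
            image=InstanceImage(
                gcsUri="gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
            ),
        ),
        Instance(
            video=InstanceVideo(gcsUri="gs://cloud-samples-data/video/animals.mp4")
        ),
    ]
)
print(request)
```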