add multi modal vtx embedding

Ishaan Jaff 2024-08-21 15:05:59 -07:00
parent 7e3dc83c0d
commit 35781ab8d5
4 changed files with 109 additions and 164 deletions

View file

@@ -9,7 +9,7 @@ import types
 import uuid
 from enum import Enum
 from functools import partial
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -597,6 +597,10 @@ class VertexLLM(BaseLLM):
         self._credentials: Optional[Any] = None
         self.project_id: Optional[str] = None
         self.async_handler: Optional[AsyncHTTPHandler] = None
+        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
+            "multimodalembedding",
+            "multimodalembedding@001",
+        ]

     def _process_response(
         self,
@@ -1557,19 +1561,6 @@ class VertexLLM(BaseLLM):
         timeout=300,
         client=None,
     ):
-        # if aembedding is True:
-        #     return self.aimage_generation(
-        #         prompt=prompt,
-        #         vertex_project=vertex_project,
-        #         vertex_location=vertex_location,
-        #         vertex_credentials=vertex_credentials,
-        #         model=model,
-        #         client=client,
-        #         optional_params=optional_params,
-        #         timeout=timeout,
-        #         logging_obj=logging_obj,
-        #         model_response=model_response,
-        #     )
         if client is None:
             _params = {}
@@ -1592,24 +1583,21 @@ class VertexLLM(BaseLLM):
         optional_params = optional_params or {}

         request_data = VertexMultimodalEmbeddingRequest()

-        vertex_request_instance = Instance(**optional_params)
-
-        # if "image" in optional_params:
-        #     vertex_request_instance["image"] = optional_params["image"]
-
-        # if "video" in optional_params:
-        #     vertex_request_instance["video"] = optional_params["video"]
-
-        # if "text" in optional_params:
-        #     vertex_request_instance["text"] = optional_params["text"]
-
-        if isinstance(input, str):
-            vertex_request_instance["text"] = input
-
-        request_data["instances"] = [vertex_request_instance]
+        if "instances" in optional_params:
+            request_data["instances"] = optional_params["instances"]
+        else:
+            # construct instances
+            vertex_request_instance = Instance(**optional_params)
+
+            if isinstance(input, str):
+                vertex_request_instance["text"] = input
+
+            request_data["instances"] = [vertex_request_instance]

         request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
         logging_obj.pre_call(
-            input=input,
+            input=[],
             api_key=None,
             additional_args={
                 "complete_input_dict": optional_params,
@@ -1618,7 +1606,7 @@ class VertexLLM(BaseLLM):
         )

         logging_obj.pre_call(
-            input=input,
+            input=[],
             api_key=None,
             additional_args={
                 "complete_input_dict": optional_params,
@ -1626,32 +1614,30 @@ class VertexLLM(BaseLLM):
}, },
) )
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
}
if aembedding is True:
return self.async_multimodal_embedding(
model=model,
api_base=url,
data=request_data,
timeout=timeout,
headers=headers,
client=client,
model_response=model_response,
)
response = sync_handler.post( response = sync_handler.post(
url=url, url=url,
headers={ headers=headers,
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
},
data=json.dumps(request_data), data=json.dumps(request_data),
) )
if response.status_code != 200: if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}") raise Exception(f"Error: {response.status_code} {response.text}")
"""
Vertex AI Image generation response example:
{
"predictions": [
{
"bytesBase64Encoded": "BASE64_IMG_BYTES",
"mimeType": "image/png"
},
{
"mimeType": "image/png",
"bytesBase64Encoded": "BASE64_IMG_BYTES"
}
]
}
"""
_json_response = response.json() _json_response = response.json()
if "predictions" not in _json_response: if "predictions" not in _json_response:
@@ -1667,125 +1653,48 @@ class VertexLLM(BaseLLM):

         return model_response

-    # async def aimage_generation(
-    #     self,
-    #     prompt: str,
-    #     vertex_project: Optional[str],
-    #     vertex_location: Optional[str],
-    #     vertex_credentials: Optional[str],
-    #     model_response: litellm.ImageResponse,
-    #     model: Optional[
-    #         str
-    #     ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
-    #     client: Optional[AsyncHTTPHandler] = None,
-    #     optional_params: Optional[dict] = None,
-    #     timeout: Optional[int] = None,
-    #     logging_obj=None,
-    # ):
-    #     response = None
-    #     if client is None:
-    #         _params = {}
-    #         if timeout is not None:
-    #             if isinstance(timeout, float) or isinstance(timeout, int):
-    #                 _httpx_timeout = httpx.Timeout(timeout)
-    #                 _params["timeout"] = _httpx_timeout
-    #             else:
-    #                 _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
-    #         self.async_handler = AsyncHTTPHandler(**_params)  # type: ignore
-    #     else:
-    #         self.async_handler = client  # type: ignore
-
-    #     # make POST request to
-    #     # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
-    #     url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
-
-    #     """
-    #     Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
-    #     curl -X POST \
-    #     -H "Authorization: Bearer $(gcloud auth print-access-token)" \
-    #     -H "Content-Type: application/json; charset=utf-8" \
-    #     -d {
-    #         "instances": [
-    #             {
-    #                 "prompt": "a cat"
-    #             }
-    #         ],
-    #         "parameters": {
-    #             "sampleCount": 1
-    #         }
-    #     } \
-    #     "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
-    #     """
-
-    #     auth_header, _ = self._ensure_access_token(
-    #         credentials=vertex_credentials, project_id=vertex_project
-    #     )
-    #     optional_params = optional_params or {
-    #         "sampleCount": 1
-    #     }  # default optional params
-
-    #     request_data = {
-    #         "instances": [{"prompt": prompt}],
-    #         "parameters": optional_params,
-    #     }
-
-    #     request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
-    #     logging_obj.pre_call(
-    #         input=prompt,
-    #         api_key=None,
-    #         additional_args={
-    #             "complete_input_dict": optional_params,
-    #             "request_str": request_str,
-    #         },
-    #     )
-
-    #     response = await self.async_handler.post(
-    #         url=url,
-    #         headers={
-    #             "Content-Type": "application/json; charset=utf-8",
-    #             "Authorization": f"Bearer {auth_header}",
-    #         },
-    #         data=json.dumps(request_data),
-    #     )
-
-    #     if response.status_code != 200:
-    #         raise Exception(f"Error: {response.status_code} {response.text}")
-
-    #     """
-    #     Vertex AI Image generation response example:
-    #     {
-    #         "predictions": [
-    #             {
-    #                 "bytesBase64Encoded": "BASE64_IMG_BYTES",
-    #                 "mimeType": "image/png"
-    #             },
-    #             {
-    #                 "mimeType": "image/png",
-    #                 "bytesBase64Encoded": "BASE64_IMG_BYTES"
-    #             }
-    #         ]
-    #     }
-    #     """
-
-    #     _json_response = response.json()
-    #     if "predictions" not in _json_response:
-    #         raise litellm.InternalServerError(
-    #             message=f"image generation response does not contain 'predictions', got {_json_response}",
-    #             llm_provider="vertex_ai",
-    #             model=model,
-    #         )
-    #     _predictions = _json_response["predictions"]
-
-    #     _response_data: List[Image] = []
-    #     for _prediction in _predictions:
-    #         _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
-    #         image_object = Image(b64_json=_bytes_base64_encoded)
-    #         _response_data.append(image_object)
-
-    #     model_response.data = _response_data
-
-    #     return model_response
+    async def async_multimodal_embedding(
+        self,
+        model: str,
+        api_base: str,
+        data: VertexMultimodalEmbeddingRequest,
+        model_response: litellm.EmbeddingResponse,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> litellm.EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")

+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        model_response.data = _predictions
+        model_response.model = model
+
+        return model_response


 class ModelResponseIterator:
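
For reference, a minimal sketch of the request body the rewritten multimodal_embedding path builds when the caller passes instances directly. The instance shape is taken from the new test further down; the URL pattern is assumed by analogy with the deleted image-generation comments above, and PROJECT_ID / us-central1 are placeholders, not values from this commit.

import json

# Hedged sketch, not the library's code: the JSON that multimodal_embedding
# serializes and POSTs to the Vertex :predict endpoint when optional_params
# carries a prebuilt "instances" list.
request_data = {
    "instances": [
        {
            "image": {
                "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
            },
            "text": "this is a unicorn",
        }
    ]
}

# Assumed endpoint shape (PROJECT_ID and location are placeholders):
url = (
    "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID"
    "/locations/us-central1/publishers/google/models/multimodalembedding@001:predict"
)

print(json.dumps(request_data, indent=2))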

View file

@@ -3477,7 +3477,11 @@ def embedding(
             or get_secret("VERTEX_CREDENTIALS")
         )

-        if "image" in optional_params or "video" in optional_params:
+        if (
+            "image" in optional_params
+            or "video" in optional_params
+            or model in vertex_chat_completion.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
+        ):
             # multimodal embedding is supported on vertex httpx
             response = vertex_chat_completion.multimodal_embedding(
                 model=model,
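
The gate above used to fire only when an image or video kwarg was present; it now also fires on the model name alone, so text-only and instances-only calls to a multimodalembedding model reach the httpx multimodal path. A minimal standalone sketch of the widened condition (in the commit the list lives on the VertexLLM instance; the helper name here is illustrative):

SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
    "multimodalembedding",
    "multimodalembedding@001",
]


def is_multimodal_embedding_call(model: str, optional_params: dict) -> bool:
    # Mirrors the widened gate in embedding(): any of the three conditions
    # routes the request to the multimodal path.
    return (
        "image" in optional_params
        or "video" in optional_params
        or model in SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
    )


assert is_multimodal_embedding_call("multimodalembedding@001", {})
assert is_multimodal_embedding_call("other-model", {"image": {"gcsUri": "gs://bucket/img.png"}})
assert not is_multimodal_embedding_call("textembedding-gecko", {})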

View file

@@ -1826,6 +1826,38 @@ def test_vertexai_embedding():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio()
+async def test_vertexai_multimodal_embedding():
+    image_path = "../proxy/cached_logo.jpg"
+    # Getting the base64 string
+    base64_image = encode_image(image_path)
+    print("base 64 img ", base64_image)
+
+    try:
+        litellm.set_verbose = True
+        response = await litellm.aembedding(
+            model="vertex_ai/multimodalembedding@001",
+            instances=[
+                {
+                    "image": {
+                        "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+                    },
+                    "text": "this is a unicorn",
+                },
+            ],
+        )
+        print("response:", response)
+        assert response.model == "multimodalembedding@001"
+
+        _response_data = response.data[0]
+
+        assert "imageEmbedding" in _response_data
+        assert "textEmbedding" in _response_data
+    except litellm.RateLimitError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.skip(
     reason="new test - works locally running into vertex version issues on ci/cd"
 )
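
The test exercises the async path. Assuming the synchronous entrypoint forwards the same instances kwarg through the gate shown in the previous file (which the routing suggests, but this commit only tests the async side), an equivalent sync call would look like this sketch:

import litellm

# Hedged sketch: synchronous counterpart of the async test above.
response = litellm.embedding(
    model="vertex_ai/multimodalembedding@001",
    instances=[
        {
            "image": {
                "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
            },
            "text": "this is a unicorn",
        }
    ],
)

# In this commit predictions are passed through as raw dicts, so each entry
# should carry imageEmbedding / textEmbedding keys.
print(response.data[0].keys())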

View file

@@ -121,7 +121,7 @@ import importlib.metadata
 from openai import OpenAIError as OriginalError

 from ._logging import verbose_logger
-from .caching import RedisCache, RedisSemanticCache, S3Cache, QdrantSemanticCache
+from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
 from .exceptions import (
     APIConnectionError,
     APIError,
@@ -541,7 +541,7 @@ def function_setup(
             call_type == CallTypes.embedding.value
             or call_type == CallTypes.aembedding.value
         ):
-            messages = args[1] if len(args) > 1 else kwargs["input"]
+            messages = args[1] if len(args) > 1 else kwargs.get("input", None)
         elif (
             call_type == CallTypes.image_generation.value
             or call_type == CallTypes.aimage_generation.value
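
This last change matters for the new instances-only calls: an embedding request that carries no input kwarg at all would otherwise raise KeyError inside logging setup before the request is ever sent. A minimal sketch of the failure mode this avoids:

# Hedged sketch: an instances-only multimodal embedding call has no "input"
# kwarg, so subscripting would crash during function_setup logging.
kwargs = {"instances": [{"text": "this is a unicorn"}]}

# Old behavior: kwargs["input"] -> KeyError.
# New behavior: degrade to None and let logging proceed.
messages = kwargs.get("input", None)
print(messages)  # None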