add multi modal vtx embedding

Ishaan Jaff 2024-08-21 15:05:59 -07:00
parent 7e3dc83c0d
commit 35781ab8d5
4 changed files with 109 additions and 164 deletions

@@ -9,7 +9,7 @@ import types
import uuid
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
@@ -597,6 +597,10 @@ class VertexLLM(BaseLLM):
self._credentials: Optional[Any] = None
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None
self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
"multimodalembedding",
"multimodalembedding@001",
]
def _process_response(
self,
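The list added above gives litellm a second routing trigger: main.py can now recognize a multimodal embedding model by name instead of relying only on the presence of image/video params (see the routing hunk further down). A minimal sketch of the check; the handler instance here is hypothetical:

vertex_llm = VertexLLM()
model = "multimodalembedding@001"
# name-based routing: fires even when the caller passes only text
if model in vertex_llm.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS:
    pass  # dispatch to the multimodal embedding code path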
@@ -1557,19 +1561,6 @@
timeout=300,
client=None,
):
# if aembedding is True:
# return self.aimage_generation(
# prompt=prompt,
# vertex_project=vertex_project,
# vertex_location=vertex_location,
# vertex_credentials=vertex_credentials,
# model=model,
# client=client,
# optional_params=optional_params,
# timeout=timeout,
# logging_obj=logging_obj,
# model_response=model_response,
# )
if client is None:
_params = {}
@@ -1592,16 +1583,13 @@
optional_params = optional_params or {}
request_data = VertexMultimodalEmbeddingRequest()
if "instances" in optional_params:
request_data["instances"] = optional_params["instances"]
else:
# construct instances
vertex_request_instance = Instance(**optional_params)
# if "image" in optional_params:
# vertex_request_instance["image"] = optional_params["image"]
# if "video" in optional_params:
# vertex_request_instance["video"] = optional_params["video"]
# if "text" in optional_params:
# vertex_request_instance["text"] = optional_params["text"]
if isinstance(input, str):
vertex_request_instance["text"] = input
@@ -1609,7 +1597,7 @@
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
logging_obj.pre_call(
input=input,
input=[],
api_key=None,
additional_args={
"complete_input_dict": optional_params,
@@ -1618,7 +1606,7 @@
)
logging_obj.pre_call(
input=input,
input=[],
api_key=None,
additional_args={
"complete_input_dict": optional_params,
@@ -1626,32 +1614,30 @@
},
)
response = sync_handler.post(
url=url,
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
},
}
if aembedding is True:
return self.async_multimodal_embedding(
model=model,
api_base=url,
data=request_data,
timeout=timeout,
headers=headers,
client=client,
model_response=model_response,
)
response = sync_handler.post(
url=url,
headers=headers,
data=json.dumps(request_data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
"""
Vertex AI Image generation response example:
{
"predictions": [
{
"bytesBase64Encoded": "BASE64_IMG_BYTES",
"mimeType": "image/png"
},
{
"mimeType": "image/png",
"bytesBase64Encoded": "BASE64_IMG_BYTES"
}
]
}
"""
_json_response = response.json()
if "predictions" not in _json_response:
@@ -1667,125 +1653,48 @@
return model_response
# async def aimage_generation(
# self,
# prompt: str,
# vertex_project: Optional[str],
# vertex_location: Optional[str],
# vertex_credentials: Optional[str],
# model_response: litellm.ImageResponse,
# model: Optional[
# str
# ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
# client: Optional[AsyncHTTPHandler] = None,
# optional_params: Optional[dict] = None,
# timeout: Optional[int] = None,
# logging_obj=None,
# ):
# response = None
# if client is None:
# _params = {}
# if timeout is not None:
# if isinstance(timeout, float) or isinstance(timeout, int):
# _httpx_timeout = httpx.Timeout(timeout)
# _params["timeout"] = _httpx_timeout
# else:
# _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
async def async_multimodal_embedding(
self,
model: str,
api_base: str,
data: VertexMultimodalEmbeddingRequest,
model_response: litellm.EmbeddingResponse,
timeout: Optional[Union[float, httpx.Timeout]],
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> litellm.EmbeddingResponse:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore
else:
client = client # type: ignore
# self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
# else:
# self.async_handler = client # type: ignore
try:
response = await client.post(api_base, headers=headers, json=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise VertexAIError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
# # make POST request to
# # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
# url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
_json_response = response.json()
if "predictions" not in _json_response:
raise litellm.InternalServerError(
message=f"embedding response does not contain 'predictions', got {_json_response}",
llm_provider="vertex_ai",
model=model,
)
_predictions = _json_response["predictions"]
# """
# Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
# curl -X POST \
# -H "Authorization: Bearer $(gcloud auth print-access-token)" \
# -H "Content-Type: application/json; charset=utf-8" \
# -d {
# "instances": [
# {
# "prompt": "a cat"
# }
# ],
# "parameters": {
# "sampleCount": 1
# }
# } \
# "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
# """
# auth_header, _ = self._ensure_access_token(
# credentials=vertex_credentials, project_id=vertex_project
# )
# optional_params = optional_params or {
# "sampleCount": 1
# } # default optional params
model_response.data = _predictions
model_response.model = model
# request_data = {
# "instances": [{"prompt": prompt}],
# "parameters": optional_params,
# }
# request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
# logging_obj.pre_call(
# input=prompt,
# api_key=None,
# additional_args={
# "complete_input_dict": optional_params,
# "request_str": request_str,
# },
# )
# response = await self.async_handler.post(
# url=url,
# headers={
# "Content-Type": "application/json; charset=utf-8",
# "Authorization": f"Bearer {auth_header}",
# },
# data=json.dumps(request_data),
# )
# if response.status_code != 200:
# raise Exception(f"Error: {response.status_code} {response.text}")
# """
# Vertex AI Image generation response example:
# {
# "predictions": [
# {
# "bytesBase64Encoded": "BASE64_IMG_BYTES",
# "mimeType": "image/png"
# },
# {
# "mimeType": "image/png",
# "bytesBase64Encoded": "BASE64_IMG_BYTES"
# }
# ]
# }
# """
# _json_response = response.json()
# if "predictions" not in _json_response:
# raise litellm.InternalServerError(
# message=f"image generation response does not contain 'predictions', got {_json_response}",
# llm_provider="vertex_ai",
# model=model,
# )
# _predictions = _json_response["predictions"]
# _response_data: List[Image] = []
# for _prediction in _predictions:
# _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
# image_object = Image(b64_json=_bytes_base64_encoded)
# _response_data.append(image_object)
# model_response.data = _response_data
# return model_response
return model_response
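A hedged sketch of exercising the new coroutine directly; in normal use it is reached via multimodal_embedding with aembedding=True, and the endpoint URL and bearer token below are illustrative placeholders:

import asyncio
import httpx
import litellm

async def _demo():
    vertex_llm = VertexLLM()  # the class defined above
    return await vertex_llm.async_multimodal_embedding(
        model="multimodalembedding@001",
        # placeholder URL; the real one is built from project and location
        api_base="https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/multimodalembedding@001:predict",
        data={"instances": [{"text": "this is a unicorn"}]},
        model_response=litellm.EmbeddingResponse(),
        timeout=httpx.Timeout(timeout=600.0, connect=5.0),
        headers={
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": "Bearer <access-token>",  # placeholder
        },
    )

response = asyncio.run(_demo())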
class ModelResponseIterator:

@@ -3477,7 +3477,11 @@ def embedding(
or get_secret("VERTEX_CREDENTIALS")
)
if "image" in optional_params or "video" in optional_params:
if (
"image" in optional_params
or "video" in optional_params
or model in vertex_chat_completion.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
):
# multimodal embedding is supported on vertex httpx
response = vertex_chat_completion.multimodal_embedding(
model=model,
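On the caller side, any of the three triggers (an image param, a video param, or a recognized model name) now selects this path. A sketch of the simplest call the name-based trigger enables, assuming Vertex credentials are already configured:

import litellm

# text-only input against a multimodal model; before this change the
# model name alone would not have selected the multimodal code path
response = litellm.embedding(
    model="vertex_ai/multimodalembedding@001",
    input="this is a unicorn",  # becomes the instance's "text" field
)
print(response.data[0])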

@@ -1826,6 +1826,38 @@ def test_vertexai_embedding():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio()
async def test_vertexai_multimodal_embedding():
image_path = "../proxy/cached_logo.jpg"
# Getting the base64 string
base64_image = encode_image(image_path)
print("base 64 img ", base64_image)
try:
litellm.set_verbose = True
response = await litellm.aembedding(
model="vertex_ai/multimodalembedding@001",
instances=[
{
"image": {
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
},
"text": "this is a unicorn",
},
],
)
print(f"response:", response)
assert response.model == "multimodalembedding@001"
_response_data = response.data[0]
assert "imageEmbedding" in _response_data
assert "textEmbedding" in _response_data
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(
reason="new test - works locally running into vertex version issues on ci/cd"
)

@@ -121,7 +121,7 @@ import importlib.metadata
from openai import OpenAIError as OriginalError
from ._logging import verbose_logger
from .caching import RedisCache, RedisSemanticCache, S3Cache, QdrantSemanticCache
from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
from .exceptions import (
APIConnectionError,
APIError,
@@ -541,7 +541,7 @@ def function_setup(
call_type == CallTypes.embedding.value
or call_type == CallTypes.aembedding.value
):
messages = args[1] if len(args) > 1 else kwargs["input"]
messages = args[1] if len(args) > 1 else kwargs.get("input", None)
elif (
call_type == CallTypes.image_generation.value
or call_type == CallTypes.aimage_generation.value
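The switch from kwargs["input"] to kwargs.get("input", None) is what lets the multimodal path log cleanly: a caller may pass only instances=... with no input at all, and the old bracketed lookup would raise KeyError during logging setup, before any request was built. A minimal illustration:

kwargs = {"instances": [{"text": "this is a unicorn"}]}  # no "input" key

# old behavior: kwargs["input"] raises KeyError: 'input'
# new behavior: returns None, so function_setup proceeds
messages = kwargs.get("input", None)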