Merge pull request #5326 from BerriAI/litellm_Add_vertex_multimodal_embedding

[Feat] add vertex multimodal embedding support
Commit dd524a4f50 by Ishaan Jaff, 2024-08-21 17:06:43 -07:00 (committed by GitHub)
5 changed files with 242 additions and 16 deletions
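
For orientation, a minimal usage sketch of what this PR enables, adapted from the new test added below; valid Vertex AI credentials and project settings are assumed:

# Sketch only: mirrors test_vertexai_multimodal_embedding further down.
import litellm

response = litellm.embedding(
    model="vertex_ai/multimodalembedding@001",
    input=[
        {
            "image": {
                "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
            },
            "text": "this is a unicorn",
        },
    ],
)
print(response.data[0].keys())  # per the test: imageEmbedding, textEmbedding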


@@ -9,7 +9,7 @@ import types
 import uuid
 from enum import Enum
 from functools import partial
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -38,12 +38,15 @@ from litellm.types.llms.vertex_ai import (
    FunctionDeclaration,
    GenerateContentResponseBody,
    GenerationConfig,
    Instance,
    InstanceVideo,
    PartType,
    RequestBody,
    SafetSettingsConfig,
    SystemInstructions,
    ToolConfig,
    Tools,
    VertexMultimodalEmbeddingRequest,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -598,6 +601,10 @@ class VertexLLM(BaseLLM):
        self._credentials: Optional[Any] = None
        self.project_id: Optional[str] = None
        self.async_handler: Optional[AsyncHTTPHandler] = None
        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
            "multimodalembedding",
            "multimodalembedding@001",
        ]

    def _process_response(
        self,
@@ -1541,6 +1548,160 @@
        return model_response

    def multimodal_embedding(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: litellm.EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str] = None,
        logging_obj=None,
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding=False,
        timeout=300,
        client=None,
    ):
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
        else:
            sync_handler = client  # type: ignore

        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"

        auth_header, _ = self._ensure_access_token(
            credentials=vertex_credentials, project_id=vertex_project
        )
        optional_params = optional_params or {}

        request_data = VertexMultimodalEmbeddingRequest()

        if "instances" in optional_params:
            request_data["instances"] = optional_params["instances"]
        elif isinstance(input, list):
            request_data["instances"] = input
        else:
            # construct instances
            vertex_request_instance = Instance(**optional_params)

            if isinstance(input, str):
                vertex_request_instance["text"] = input

            request_data["instances"] = [vertex_request_instance]

        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
        logging_obj.pre_call(
            input=[],
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )

        logging_obj.pre_call(
            input=[],
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {auth_header}",
        }

        if aembedding is True:
            return self.async_multimodal_embedding(
                model=model,
                api_base=url,
                data=request_data,
                timeout=timeout,
                headers=headers,
                client=client,
                model_response=model_response,
            )

        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        if "predictions" not in _json_response:
            raise litellm.InternalServerError(
                message=f"embedding response does not contain 'predictions', got {_json_response}",
                llm_provider="vertex_ai",
                model=model,
            )
        _predictions = _json_response["predictions"]

        model_response.data = _predictions
        model_response.model = model

        return model_response

    async def async_multimodal_embedding(
        self,
        model: str,
        api_base: str,
        data: VertexMultimodalEmbeddingRequest,
        model_response: litellm.EmbeddingResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
    ) -> litellm.EmbeddingResponse:
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = AsyncHTTPHandler(**_params)  # type: ignore
        else:
            client = client  # type: ignore

        try:
            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        _json_response = response.json()
        if "predictions" not in _json_response:
            raise litellm.InternalServerError(
                message=f"embedding response does not contain 'predictions', got {_json_response}",
                llm_provider="vertex_ai",
                model=model,
            )
        _predictions = _json_response["predictions"]

        model_response.data = _predictions
        model_response.model = model

        return model_response


class ModelResponseIterator:
    def __init__(self, streaming_response, sync_stream: bool):
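To make the request flow concrete, here is a sketch of the payload multimodal_embedding builds and POSTs to the :predict endpoint when input is a string and optional_params carries an image. The project, location, and bucket names are placeholders, not values from this PR:

# Illustrative sketch; my-project, us-central1, and gs://my-bucket are placeholders.
import json

url = (
    "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project"
    "/locations/us-central1/publishers/google/models/multimodalembedding@001:predict"
)
headers = {
    "Content-Type": "application/json; charset=utf-8",
    "Authorization": "Bearer <access-token>",  # from _ensure_access_token
}
# Instance(**optional_params), then vertex_request_instance["text"] = input:
request_data = {
    "instances": [
        {
            "image": {"gcsUri": "gs://my-bucket/landmark.png"},
            "text": "this is a unicorn",
        }
    ]
}
print(json.dumps(request_data))  # body sent via sync_handler.post / client.post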


@@ -3477,19 +3477,39 @@ def embedding(
             or get_secret("VERTEX_CREDENTIALS")
         )

-        response = vertex_ai.embedding(
-            model=model,
-            input=input,
-            encoding=encoding,
-            logging_obj=logging,
-            optional_params=optional_params,
-            model_response=EmbeddingResponse(),
-            vertex_project=vertex_ai_project,
-            vertex_location=vertex_ai_location,
-            vertex_credentials=vertex_credentials,
-            aembedding=aembedding,
-            print_verbose=print_verbose,
-        )
+        if (
+            "image" in optional_params
+            or "video" in optional_params
+            or model in vertex_chat_completion.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
+        ):
+            # multimodal embedding is supported on vertex httpx
+            response = vertex_chat_completion.multimodal_embedding(
+                model=model,
+                input=input,
+                encoding=encoding,
+                logging_obj=logging,
+                optional_params=optional_params,
+                model_response=EmbeddingResponse(),
+                vertex_project=vertex_ai_project,
+                vertex_location=vertex_ai_location,
+                vertex_credentials=vertex_credentials,
+                aembedding=aembedding,
+                print_verbose=print_verbose,
+            )
+        else:
+            response = vertex_ai.embedding(
+                model=model,
+                input=input,
+                encoding=encoding,
+                logging_obj=logging,
+                optional_params=optional_params,
+                model_response=EmbeddingResponse(),
+                vertex_project=vertex_ai_project,
+                vertex_location=vertex_ai_location,
+                vertex_credentials=vertex_credentials,
+                aembedding=aembedding,
+                print_verbose=print_verbose,
+            )
     elif custom_llm_provider == "oobabooga":
         response = oobabooga.embedding(
             model=model,
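A small sketch of when a request takes the new httpx multimodal path versus the existing SDK path; the non-multimodal model name below is illustrative only:

# Mirrors the `if` condition added in this hunk.
SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
    "multimodalembedding",
    "multimodalembedding@001",
]

def uses_multimodal_path(model: str, optional_params: dict) -> bool:
    return (
        "image" in optional_params
        or "video" in optional_params
        or model in SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
    )

assert uses_multimodal_path("multimodalembedding@001", {})
assert not uses_multimodal_path("textembedding-gecko", {})
# an image param routes even another model name to the new path:
assert uses_multimodal_path("textembedding-gecko", {"image": {"gcsUri": "gs://b/i.png"}})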


@@ -1836,6 +1836,36 @@ def test_vertexai_embedding():
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio()
async def test_vertexai_multimodal_embedding():
    load_vertex_ai_credentials()
    try:
        litellm.set_verbose = True
        response = await litellm.aembedding(
            model="vertex_ai/multimodalembedding@001",
            input=[
                {
                    "image": {
                        "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
                    },
                    "text": "this is a unicorn",
                },
            ],
        )
        print(f"response:", response)
        assert response.model == "multimodalembedding@001"
        _response_data = response.data[0]
        assert "imageEmbedding" in _response_data
        assert "textEmbedding" in _response_data
    except litellm.RateLimitError as e:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.skip(
    reason="new test - works locally running into vertex version issues on ci/cd"
)
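The new test exercises only the async path (aembedding, which routes through async_multimodal_embedding). A hypothetical sync counterpart, exercising sync_handler.post with the same instance shape, is not part of this PR but could look like:

def test_vertexai_multimodal_embedding_sync():  # hypothetical, not in this PR
    load_vertex_ai_credentials()  # same helper the async test uses
    response = litellm.embedding(  # no await: takes the sync_handler path
        model="vertex_ai/multimodalembedding@001",
        input=[
            {
                "image": {
                    "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
                },
                "text": "this is a unicorn",
            },
        ],
    )
    assert response.model == "multimodalembedding@001"
    assert "imageEmbedding" in response.data[0]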


@@ -1,6 +1,6 @@
 import json
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union

 from typing_extensions import (
     Protocol,

@@ -305,3 +305,18 @@ class ResponseTuningJob(TypedDict):
    ]
    createTime: Optional[str]
    updateTime: Optional[str]


class InstanceVideo(TypedDict, total=False):
    gcsUri: str
    videoSegmentConfig: Tuple[float, float, float]


class Instance(TypedDict, total=False):
    text: str
    image: Dict[str, str]
    video: InstanceVideo


class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
    instances: List[Instance]
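A short sketch of how the new TypedDicts compose into a request body. The (start, end, interval) reading of the videoSegmentConfig tuple is an assumption based on Vertex AI's video segment options, and the GCS path is a placeholder:

from litellm.types.llms.vertex_ai import (
    Instance,
    InstanceVideo,
    VertexMultimodalEmbeddingRequest,
)

video = InstanceVideo(
    gcsUri="gs://my-bucket/clip.mp4",  # placeholder path
    videoSegmentConfig=(0.0, 10.0, 5.0),  # assumed (start, end, interval) seconds
)
instance = Instance(text="a short description", video=video)
request = VertexMultimodalEmbeddingRequest(instances=[instance])
# total=False on each TypedDict means every key above is optional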


@@ -541,7 +541,7 @@ def function_setup(
             call_type == CallTypes.embedding.value
             or call_type == CallTypes.aembedding.value
         ):
-            messages = args[1] if len(args) > 1 else kwargs["input"]
+            messages = args[1] if len(args) > 1 else kwargs.get("input", None)
         elif (
             call_type == CallTypes.image_generation.value
             or call_type == CallTypes.aimage_generation.value
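Why this one-line change matters: with the old subscript lookup, an embedding call reaching function_setup without an `input` keyword raised a KeyError before any provider code ran; .get() falls back to None instead. A minimal illustration:

args = ("vertex_ai/multimodalembedding@001",)  # only `model` passed positionally
kwargs = {}  # no `input` keyword supplied

# old: messages = args[1] if len(args) > 1 else kwargs["input"]  -> KeyError: 'input'
messages = args[1] if len(args) > 1 else kwargs.get("input", None)
assert messages is None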