add test for test_vertexai_multimodal_embedding_text_input

Ishaan Jaff 2024-08-30 09:19:48 -07:00
parent 4e5965137c
commit 10771e3bde
4 changed files with 278 additions and 164 deletions
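This change moves Vertex AI multimodal embedding out of VertexLLM into a dedicated VertexMultimodalEmbedding handler and adds a test for plain text input. As a rough usage sketch of the call shape the new test exercises (assuming Vertex AI credentials are already configured; the model name and input mirror the test added below):

import litellm

# Text-only input is routed to the multimodal embedding handler because the
# model appears in SUPPORTED_MULTIMODAL_EMBEDDING_MODELS.
response = litellm.embedding(
    model="vertex_ai/multimodalembedding@001",
    input=["this is a unicorn"],
)
print(response.data[0]["textEmbedding"])  # prediction shape per the mocked test response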


@@ -41,15 +41,12 @@ from litellm.types.llms.vertex_ai import (
     FunctionDeclaration,
     GenerateContentResponseBody,
     GenerationConfig,
-    Instance,
-    InstanceVideo,
     PartType,
     RequestBody,
     SafetSettingsConfig,
     SystemInstructions,
     ToolConfig,
     Tools,
-    VertexMultimodalEmbeddingRequest,
 )
 from litellm.types.utils import GenericStreamingChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -811,10 +808,6 @@ class VertexLLM(BaseLLM):
         self._credentials: Optional[Any] = None
         self.project_id: Optional[str] = None
         self.async_handler: Optional[AsyncHTTPHandler] = None
-        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
-            "multimodalembedding",
-            "multimodalembedding@001",
-        ]
 
     def _process_response(
         self,
@@ -1727,161 +1720,6 @@ class VertexLLM(BaseLLM):
 
         return model_response
 
-    def multimodal_embedding(
-        self,
-        model: str,
-        input: Union[list, str],
-        print_verbose,
-        model_response: litellm.EmbeddingResponse,
-        custom_llm_provider: Literal["gemini", "vertex_ai"],
-        optional_params: dict,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        logging_obj=None,
-        encoding=None,
-        vertex_project=None,
-        vertex_location=None,
-        vertex_credentials=None,
-        aembedding=False,
-        timeout=300,
-        client=None,
-    ):
-        auth_header, url = self._get_token_and_url(
-            model=model,
-            gemini_api_key=api_key,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
-            stream=None,
-            custom_llm_provider=custom_llm_provider,
-            api_base=api_base,
-            should_use_v1beta1_features=False,
-            mode="embedding",
-        )
-
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    _httpx_timeout = httpx.Timeout(timeout)
-                    _params["timeout"] = _httpx_timeout
-            else:
-                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
-
-            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
-        else:
-            sync_handler = client  # type: ignore
-
-        optional_params = optional_params or {}
-
-        request_data = VertexMultimodalEmbeddingRequest()
-
-        if "instances" in optional_params:
-            request_data["instances"] = optional_params["instances"]
-        elif isinstance(input, list):
-            request_data["instances"] = input
-        else:
-            # construct instances
-            vertex_request_instance = Instance(**optional_params)
-
-            if isinstance(input, str):
-                vertex_request_instance["text"] = input
-
-            request_data["instances"] = [vertex_request_instance]
-
-        headers = {
-            "Content-Type": "application/json; charset=utf-8",
-            "Authorization": f"Bearer {auth_header}",
-        }
-
-        ## LOGGING
-        logging_obj.pre_call(
-            input=input,
-            api_key="",
-            additional_args={
-                "complete_input_dict": request_data,
-                "api_base": url,
-                "headers": headers,
-            },
-        )
-
-        if aembedding is True:
-            return self.async_multimodal_embedding(
-                model=model,
-                api_base=url,
-                data=request_data,
-                timeout=timeout,
-                headers=headers,
-                client=client,
-                model_response=model_response,
-            )
-
-        response = sync_handler.post(
-            url=url,
-            headers=headers,
-            data=json.dumps(request_data),
-        )
-
-        if response.status_code != 200:
-            raise Exception(f"Error: {response.status_code} {response.text}")
-
-        _json_response = response.json()
-        if "predictions" not in _json_response:
-            raise litellm.InternalServerError(
-                message=f"embedding response does not contain 'predictions', got {_json_response}",
-                llm_provider="vertex_ai",
-                model=model,
-            )
-        _predictions = _json_response["predictions"]
-
-        model_response.data = _predictions
-        model_response.model = model
-
-        return model_response
-
-    async def async_multimodal_embedding(
-        self,
-        model: str,
-        api_base: str,
-        data: VertexMultimodalEmbeddingRequest,
-        model_response: litellm.EmbeddingResponse,
-        timeout: Optional[Union[float, httpx.Timeout]],
-        headers={},
-        client: Optional[AsyncHTTPHandler] = None,
-    ) -> litellm.EmbeddingResponse:
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    timeout = httpx.Timeout(timeout)
-                _params["timeout"] = timeout
-            client = AsyncHTTPHandler(**_params)  # type: ignore
-        else:
-            client = client  # type: ignore
-
-        try:
-            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
-            response.raise_for_status()
-        except httpx.HTTPStatusError as err:
-            error_code = err.response.status_code
-            raise VertexAIError(status_code=error_code, message=err.response.text)
-        except httpx.TimeoutException:
-            raise VertexAIError(status_code=408, message="Timeout error occurred.")
-
-        _json_response = response.json()
-        if "predictions" not in _json_response:
-            raise litellm.InternalServerError(
-                message=f"embedding response does not contain 'predictions', got {_json_response}",
-                llm_provider="vertex_ai",
-                model=model,
-            )
-        _predictions = _json_response["predictions"]
-
-        model_response.data = _predictions
-        model_response.model = model
-
-        return model_response
-
 
 class ModelResponseIterator:
     def __init__(self, streaming_response, sync_stream: bool):


@@ -0,0 +1,216 @@
+import json
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
+    VertexAIError,
+    VertexLLM,
+)
+from litellm.types.llms.vertex_ai import (
+    Instance,
+    InstanceVideo,
+    VertexMultimodalEmbeddingRequest,
+)
+
+
+class VertexMultimodalEmbedding(VertexLLM):
+    def __init__(self) -> None:
+        super().__init__()
+        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
+            "multimodalembedding",
+            "multimodalembedding@001",
+        ]
+
+    def multimodal_embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        print_verbose,
+        model_response: litellm.EmbeddingResponse,
+        custom_llm_provider: Literal["gemini", "vertex_ai"],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        logging_obj=None,
+        encoding=None,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        aembedding=False,
+        timeout=300,
+        client=None,
+    ):
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=None,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="embedding",
+        )
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        optional_params = optional_params or {}
+
+        request_data = VertexMultimodalEmbeddingRequest()
+
+        if "instances" in optional_params:
+            request_data["instances"] = optional_params["instances"]
+        elif isinstance(input, list):
+            vertex_instances: List[Instance] = self.process_openai_embedding_input(
+                _input=input
+            )
+            request_data["instances"] = vertex_instances
+        else:
+            # construct instances
+            vertex_request_instance = Instance(**optional_params)
+
+            if isinstance(input, str):
+                vertex_request_instance["text"] = input
+
+            request_data["instances"] = [vertex_request_instance]
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {auth_header}",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.async_multimodal_embedding(
+                model=model,
+                api_base=url,
+                data=request_data,
+                timeout=timeout,
+                headers=headers,
+                client=client,
+                model_response=model_response,
+            )
+
+        response = sync_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        model_response.data = _predictions
+        model_response.model = model
+
+        return model_response
+
+    async def async_multimodal_embedding(
+        self,
+        model: str,
+        api_base: str,
+        data: VertexMultimodalEmbeddingRequest,
+        model_response: litellm.EmbeddingResponse,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> litellm.EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        model_response.data = _predictions
+        model_response.model = model
+
+        return model_response
+
+    def process_openai_embedding_input(
+        self, _input: Union[list, str]
+    ) -> List[Instance]:
+        """
+        Process the input for multimodal embedding requests.
+
+        Args:
+            _input (Union[list, str]): The input data to process.
+
+        Returns:
+            List[Instance]: A list of processed VertexAI Instance objects.
+        """
+
+        _input_list = None
+        if not isinstance(_input, list):
+            _input_list = [_input]
+        else:
+            _input_list = _input
+
+        processed_instances = []
+        for element in _input_list:
+            if not isinstance(element, dict):
+                # assuming that input is a list of strings
+                # example: input = ["hello from litellm"]
+                instance = Instance(text=element)
+            else:
+                # assume the element is already an Instance-shaped dict
+                instance = Instance(**element)
+            processed_instances.append(instance)
+
+        return processed_instances
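For illustration, a minimal sketch of the input normalization that process_openai_embedding_input performs (Instance is the TypedDict imported at the top of the new file; the sample strings are invented):

from litellm.types.llms.vertex_ai import Instance

# OpenAI-style embedding input: a flat list of strings
# (dict elements would instead be passed through as Instance(**element)).
openai_style_input = ["hello from litellm", "this is a unicorn"]

# Each string becomes one Vertex AI instance of the form {"text": ...}.
vertex_instances = [Instance(text=element) for element in openai_style_input]
# -> [{"text": "hello from litellm"}, {"text": "this is a unicorn"}]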


@@ -132,6 +132,9 @@ from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler
 from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
+from .llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import (
+    VertexMultimodalEmbedding,
+)
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )
@@ -175,6 +178,7 @@ triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
+vertex_multimodal_embedding = VertexMultimodalEmbedding()
 google_batch_embeddings = GoogleBatchEmbeddings()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 vertex_text_to_speech = VertexTextToSpeechAPI()
@@ -3581,10 +3585,11 @@ def embedding(
         if (
             "image" in optional_params
             or "video" in optional_params
-            or model in vertex_chat_completion.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
+            or model
+            in vertex_multimodal_embedding.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
         ):
             # multimodal embedding is supported on vertex httpx
-            response = vertex_chat_completion.multimodal_embedding(
+            response = vertex_multimodal_embedding.multimodal_embedding(
                 model=model,
                 input=input,
                 encoding=encoding,

@@ -1934,6 +1934,61 @@ async def test_vertexai_multimodal_embedding():
         print("Response:", response)
 
 
+@pytest.mark.asyncio
+async def test_vertexai_multimodal_embedding_text_input():
+    load_vertex_ai_credentials()
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "predictions": [
+                {
+                    "textEmbedding": [0.4, 0.5, 0.6],  # Simplified example
+                }
+            ]
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "instances": [
+            {
+                "text": "this is a unicorn",
+            }
+        ]
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.aembedding function
+        response = await litellm.aembedding(
+            model="vertex_ai/multimodalembedding@001",
+            input=[
+                "this is a unicorn",
+            ],
+        )
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_vertexai = kwargs["json"]
+
+        print("args to vertex ai call:", args_to_vertexai)
+
+        assert args_to_vertexai == expected_payload
+        assert response.model == "multimodalembedding@001"
+        assert len(response.data) == 1
+        response_data = response.data[0]
+        assert "textEmbedding" in response_data
+
+        # Optional: Print for debugging
+        print("Arguments passed to Vertex AI:", args_to_vertexai)
+        print("Response:", response)
+
+
 @pytest.mark.skip(
     reason="new test - works locally running into vertex version issues on ci/cd"
 )