mirror of https://github.com/BerriAI/litellm.git — synced 2025-04-26 19:24:27 +00:00
add test for test_vertexai_multimodal_embedding_text_input

parent 4e5965137c
commit 10771e3bde

4 changed files with 278 additions and 164 deletions
@@ -41,15 +41,12 @@ from litellm.types.llms.vertex_ai import (
     FunctionDeclaration,
     GenerateContentResponseBody,
     GenerationConfig,
-    Instance,
-    InstanceVideo,
     PartType,
     RequestBody,
     SafetSettingsConfig,
     SystemInstructions,
     ToolConfig,
     Tools,
-    VertexMultimodalEmbeddingRequest,
 )
 from litellm.types.utils import GenericStreamingChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

@@ -811,10 +808,6 @@ class VertexLLM(BaseLLM):
         self._credentials: Optional[Any] = None
         self.project_id: Optional[str] = None
         self.async_handler: Optional[AsyncHTTPHandler] = None
-        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
-            "multimodalembedding",
-            "multimodalembedding@001",
-        ]
 
     def _process_response(
         self,

@@ -1727,161 +1720,6 @@ class VertexLLM(BaseLLM):
 
         return model_response
 
-    def multimodal_embedding(
-        self,
-        model: str,
-        input: Union[list, str],
-        print_verbose,
-        model_response: litellm.EmbeddingResponse,
-        custom_llm_provider: Literal["gemini", "vertex_ai"],
-        optional_params: dict,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        logging_obj=None,
-        encoding=None,
-        vertex_project=None,
-        vertex_location=None,
-        vertex_credentials=None,
-        aembedding=False,
-        timeout=300,
-        client=None,
-    ):
-        auth_header, url = self._get_token_and_url(
-            model=model,
-            gemini_api_key=api_key,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
-            stream=None,
-            custom_llm_provider=custom_llm_provider,
-            api_base=api_base,
-            should_use_v1beta1_features=False,
-            mode="embedding",
-        )
-
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    _httpx_timeout = httpx.Timeout(timeout)
-                    _params["timeout"] = _httpx_timeout
-            else:
-                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
-
-            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
-        else:
-            sync_handler = client  # type: ignore
-
-        optional_params = optional_params or {}
-
-        request_data = VertexMultimodalEmbeddingRequest()
-
-        if "instances" in optional_params:
-            request_data["instances"] = optional_params["instances"]
-        elif isinstance(input, list):
-            request_data["instances"] = input
-        else:
-            # construct instances
-            vertex_request_instance = Instance(**optional_params)
-
-            if isinstance(input, str):
-                vertex_request_instance["text"] = input
-
-            request_data["instances"] = [vertex_request_instance]
-
-        headers = {
-            "Content-Type": "application/json; charset=utf-8",
-            "Authorization": f"Bearer {auth_header}",
-        }
-
-        ## LOGGING
-        logging_obj.pre_call(
-            input=input,
-            api_key="",
-            additional_args={
-                "complete_input_dict": request_data,
-                "api_base": url,
-                "headers": headers,
-            },
-        )
-
-        if aembedding is True:
-            return self.async_multimodal_embedding(
-                model=model,
-                api_base=url,
-                data=request_data,
-                timeout=timeout,
-                headers=headers,
-                client=client,
-                model_response=model_response,
-            )
-
-        response = sync_handler.post(
-            url=url,
-            headers=headers,
-            data=json.dumps(request_data),
-        )
-
-        if response.status_code != 200:
-            raise Exception(f"Error: {response.status_code} {response.text}")
-
-        _json_response = response.json()
-        if "predictions" not in _json_response:
-            raise litellm.InternalServerError(
-                message=f"embedding response does not contain 'predictions', got {_json_response}",
-                llm_provider="vertex_ai",
-                model=model,
-            )
-        _predictions = _json_response["predictions"]
-
-        model_response.data = _predictions
-        model_response.model = model
-
-        return model_response
-
-    async def async_multimodal_embedding(
-        self,
-        model: str,
-        api_base: str,
-        data: VertexMultimodalEmbeddingRequest,
-        model_response: litellm.EmbeddingResponse,
-        timeout: Optional[Union[float, httpx.Timeout]],
-        headers={},
-        client: Optional[AsyncHTTPHandler] = None,
-    ) -> litellm.EmbeddingResponse:
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    timeout = httpx.Timeout(timeout)
-                _params["timeout"] = timeout
-            client = AsyncHTTPHandler(**_params)  # type: ignore
-        else:
-            client = client  # type: ignore
-
-        try:
-            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
-            response.raise_for_status()
-        except httpx.HTTPStatusError as err:
-            error_code = err.response.status_code
-            raise VertexAIError(status_code=error_code, message=err.response.text)
-        except httpx.TimeoutException:
-            raise VertexAIError(status_code=408, message="Timeout error occurred.")
-
-        _json_response = response.json()
-        if "predictions" not in _json_response:
-            raise litellm.InternalServerError(
-                message=f"embedding response does not contain 'predictions', got {_json_response}",
-                llm_provider="vertex_ai",
-                model=model,
-            )
-        _predictions = _json_response["predictions"]
-
-        model_response.data = _predictions
-        model_response.model = model
-
-        return model_response
-
 
 class ModelResponseIterator:
     def __init__(self, streaming_response, sync_stream: bool):

@@ -0,0 +1,216 @@
+import json
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
+    VertexAIError,
+    VertexLLM,
+)
+from litellm.types.llms.vertex_ai import (
+    Instance,
+    InstanceVideo,
+    VertexMultimodalEmbeddingRequest,
+)
+
+
+class VertexMultimodalEmbedding(VertexLLM):
+    def __init__(self) -> None:
+        super().__init__()
+        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
+            "multimodalembedding",
+            "multimodalembedding@001",
+        ]
+
+    def multimodal_embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        print_verbose,
+        model_response: litellm.EmbeddingResponse,
+        custom_llm_provider: Literal["gemini", "vertex_ai"],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        logging_obj=None,
+        encoding=None,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        aembedding=False,
+        timeout=300,
+        client=None,
+    ):
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=None,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="embedding",
+        )
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        optional_params = optional_params or {}
+
+        request_data = VertexMultimodalEmbeddingRequest()
+
+        if "instances" in optional_params:
+            request_data["instances"] = optional_params["instances"]
+        elif isinstance(input, list):
+            vertex_instances: List[Instance] = self.process_openai_embedding_input(
+                _input=input
+            )
+            request_data["instances"] = vertex_instances
+
+        else:
+            # construct instances
+            vertex_request_instance = Instance(**optional_params)
+
+            if isinstance(input, str):
+                vertex_request_instance["text"] = input
+
+            request_data["instances"] = [vertex_request_instance]
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {auth_header}",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.async_multimodal_embedding(
+                model=model,
+                api_base=url,
+                data=request_data,
+                timeout=timeout,
+                headers=headers,
+                client=client,
+                model_response=model_response,
+            )
+
+        response = sync_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        model_response.data = _predictions
+        model_response.model = model
+
+        return model_response
+
+    async def async_multimodal_embedding(
+        self,
+        model: str,
+        api_base: str,
+        data: VertexMultimodalEmbeddingRequest,
+        model_response: litellm.EmbeddingResponse,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> litellm.EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        if "predictions" not in _json_response:
+            raise litellm.InternalServerError(
+                message=f"embedding response does not contain 'predictions', got {_json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+        _predictions = _json_response["predictions"]
+
+        model_response.data = _predictions
+        model_response.model = model
+
+        return model_response
+
+    def process_openai_embedding_input(
+        self, _input: Union[list, str]
+    ) -> List[Instance]:
+        """
+        Process the input for multimodal embedding requests.
+
+        Args:
+            _input (Union[list, str]): The input data to process.
+
+        Returns:
+            List[Instance]: A list of processed VertexAI Instance objects.
+        """
+
+        _input_list = None
+        if not isinstance(_input, list):
+            _input_list = [_input]
+        else:
+            _input_list = _input
+
+        processed_instances = []
+        for element in _input:
+            if not isinstance(element, dict):
+                # assuming that input is a list of strings
+                # example: input = ["hello from litellm"]
+                instance = Instance(text=element)
+            else:
+                # assume this is a
+                instance = Instance(**element)
+            processed_instances.append(instance)
+
+        return processed_instances

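For orientation, a minimal sketch (not part of the diff) of the conversion that the new handler's process_openai_embedding_input performs: plain strings become text instances, while dict elements are forwarded unchanged. The helper name to_instances and the sample values are illustrative only.

from typing import List, Union


def to_instances(_input: Union[list, str]) -> List[dict]:
    # mirrors process_openai_embedding_input: wrap a bare string, keep lists as-is
    elements = _input if isinstance(_input, list) else [_input]
    instances = []
    for element in elements:
        if isinstance(element, dict):
            instances.append(element)  # already instance-shaped, forwarded unchanged
        else:
            instances.append({"text": element})  # plain string -> text instance
    return instances


print(to_instances(["hello from litellm"]))  # [{'text': 'hello from litellm'}]
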
@@ -132,6 +132,9 @@ from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler
 from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
+from .llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import (
+    VertexMultimodalEmbedding,
+)
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )

@@ -175,6 +178,7 @@ triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
+vertex_multimodal_embedding = VertexMultimodalEmbedding()
 google_batch_embeddings = GoogleBatchEmbeddings()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 vertex_text_to_speech = VertexTextToSpeechAPI()

@@ -3581,10 +3585,11 @@ def embedding(
         if (
             "image" in optional_params
             or "video" in optional_params
-            or model in vertex_chat_completion.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
+            or model
+            in vertex_multimodal_embedding.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
         ):
             # multimodal embedding is supported on vertex httpx
-            response = vertex_chat_completion.multimodal_embedding(
+            response = vertex_multimodal_embedding.multimodal_embedding(
                 model=model,
                 input=input,
                 encoding=encoding,

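A hedged usage sketch of the routing above: because the model name is in SUPPORTED_MULTIMODAL_EMBEDDING_MODELS, litellm.embedding() dispatches to vertex_multimodal_embedding.multimodal_embedding(). The project and location values are placeholders, and Vertex AI credentials are assumed to be configured.

import litellm

# assumes GOOGLE_APPLICATION_CREDENTIALS (or equivalent Vertex auth) is set
response = litellm.embedding(
    model="vertex_ai/multimodalembedding@001",
    input=["this is a unicorn"],       # routed as a single text instance
    vertex_project="my-gcp-project",   # placeholder project id
    vertex_location="us-central1",     # placeholder region
)
print(response.data[0])
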
@@ -1934,6 +1934,61 @@ async def test_vertexai_multimodal_embedding():
     print("Response:", response)
 
 
+@pytest.mark.asyncio
+async def test_vertexai_multimodal_embedding_text_input():
+    load_vertex_ai_credentials()
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "predictions": [
+                {
+                    "textEmbedding": [0.4, 0.5, 0.6],  # Simplified example
+                }
+            ]
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "instances": [
+            {
+                "text": "this is a unicorn",
+            }
+        ]
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.aembedding function
+        response = await litellm.aembedding(
+            model="vertex_ai/multimodalembedding@001",
+            input=[
+                "this is a unicorn",
+            ],
+        )
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_vertexai = kwargs["json"]
+
+        print("args to vertex ai call:", args_to_vertexai)
+
+        assert args_to_vertexai == expected_payload
+        assert response.model == "multimodalembedding@001"
+        assert len(response.data) == 1
+        response_data = response.data[0]
+        assert "textEmbedding" in response_data
+
+        # Optional: Print for debugging
+        print("Arguments passed to Vertex AI:", args_to_vertexai)
+        print("Response:", response)
+
+
 @pytest.mark.skip(
     reason="new test - works locally running into vertex version issues on ci/cd"
 )

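The same mocking pattern could cover dict-shaped instances (images or videos), which the handler forwards to the request body unchanged. The sketch below is not part of this commit: the test name and gs:// URI are illustrative, and the gcsUri field follows Vertex AI's multimodal instance format rather than anything asserted here.

@pytest.mark.asyncio
async def test_vertexai_multimodal_embedding_image_dict_input_sketch():
    # sketch only: dict elements pass through process_openai_embedding_input as-is
    load_vertex_ai_credentials()
    mock_response = AsyncMock()
    mock_response.status_code = 200
    mock_response.json = lambda: {"predictions": [{"imageEmbedding": [0.1, 0.2]}]}

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.aembedding(
            model="vertex_ai/multimodalembedding@001",
            input=[{"image": {"gcsUri": "gs://bucket/example.png"}}],  # placeholder URI
        )

    mock_post.assert_called_once()
    assert len(response.data) == 1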