# litellm-mirror/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
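"""Handler for Vertex AI multimodal embeddings (e.g. multimodalembedding@001).

Routes embedding requests to the Vertex AI predict endpoint over sync and
async code paths; request/response transformation is delegated to
VertexAIMultimodalEmbeddingConfig in .transformation.
"""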
import json
from typing import Literal, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
    VertexAIError,
    VertexLLM,
)
from litellm.types.utils import EmbeddingResponse

from .transformation import VertexAIMultimodalEmbeddingConfig

vertex_multimodal_embedding_handler = VertexAIMultimodalEmbeddingConfig()


class VertexMultimodalEmbedding(VertexLLM):
    def __init__(self) -> None:
        super().__init__()
        self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
            "multimodalembedding",
            "multimodalembedding@001",
        ]

    def multimodal_embedding(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        litellm_params: dict,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        headers: dict = {},
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding=False,
        timeout=300,
        client=None,
    ) -> EmbeddingResponse:
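        """Embed multimodal input via the Vertex AI predict endpoint,
        synchronously, or asynchronously when ``aembedding`` is True."""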
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )

        auth_header, url = self._get_token_and_url(
            model=model,
            auth_header=_auth_header,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="embedding",
        )
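
        # Reuse the caller-supplied HTTP client when given; otherwise build a
        # sync handler, wrapping numeric timeouts in an httpx.Timeout.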
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
        else:
            sync_handler = client  # type: ignore

        request_data = vertex_multimodal_embedding_handler.transform_embedding_request(
            model, input, optional_params, headers
        )

        headers = vertex_multimodal_embedding_handler.validate_environment(
            headers=headers,
            model=model,
            messages=[],
            optional_params=optional_params,
            api_key=auth_header,
            api_base=api_base,
            litellm_params=litellm_params,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={
                "complete_input_dict": request_data,
                "api_base": url,
                "headers": headers,
            },
        )
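
        # Async path: return the coroutine from async_multimodal_embedding for
        # the caller to await.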
        if aembedding is True:
            return self.async_multimodal_embedding(  # type: ignore
                model=model,
                api_base=url,
                data=request_data,
                timeout=timeout,
                headers=headers,
                client=client,
                model_response=model_response,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logging_obj=logging_obj,
                api_key=api_key,
            )

        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )

        return vertex_multimodal_embedding_handler.transform_embedding_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=request_data,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )

    async def async_multimodal_embedding(
        self,
        model: str,
        api_base: str,
        optional_params: dict,
        litellm_params: dict,
        data: dict,
        model_response: litellm.EmbeddingResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        logging_obj: LiteLLMLoggingObj,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
        api_key: Optional[str] = None,
    ) -> litellm.EmbeddingResponse:
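        """Async counterpart to multimodal_embedding: POST the prepared request
        body and transform the raw response into an EmbeddingResponse."""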
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                    _params["timeout"] = timeout
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.VERTEX_AI,
                params={"timeout": timeout},
            )
        else:
            client = client  # type: ignore
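
        # POST to the Vertex predict endpoint; surface HTTP and timeout
        # failures as VertexAIError.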
        try:
            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        return vertex_multimodal_embedding_handler.transform_embedding_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=data,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )
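

# ---------------------------------------------------------------------------
# Usage sketch: this handler is normally reached through litellm.embedding()
# rather than called directly. A minimal example, assuming valid GCP
# credentials; the project id below is a placeholder:
#
#   import litellm
#
#   response = litellm.embedding(
#       model="vertex_ai/multimodalembedding@001",
#       input=["a photo of a cat"],
#       vertex_project="my-gcp-project",   # placeholder
#       vertex_location="us-central1",
#   )
#   print(len(response.data[0]["embedding"]))
# ---------------------------------------------------------------------------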