feat(vertex_httpx.py): move to calling Vertex AI via httpx (instead of their SDK). This allows us to support all their API updates.

Krrish Dholakia 2024-06-12 16:47:00 -07:00
parent 29902b49d4
commit 29169b3039
8 changed files with 431 additions and 40 deletions
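
As context for the direction this commit takes: calling Vertex AI's generateContent REST endpoint directly with httpx (rather than through the google-cloud-aiplatform SDK) looks roughly like the sketch below. The project, location, model name, and payload values are illustrative assumptions, not values taken from this diff.

# Minimal sketch of the httpx-based approach (illustrative values only, not from this commit).
import httpx
import google.auth
import google.auth.transport.requests

PROJECT = "my-gcp-project"   # assumed
LOCATION = "us-central1"     # assumed
MODEL = "gemini-1.5-pro"     # assumed

def generate_content(prompt: str) -> dict:
    # Get an OAuth2 access token via Application Default Credentials.
    creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
    creds.refresh(google.auth.transport.requests.Request())
    url = (
        f"https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT}"
        f"/locations/{LOCATION}/publishers/google/models/{MODEL}:generateContent"
    )
    payload = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]}
    response = httpx.post(
        url,
        headers={"Authorization": f"Bearer {creds.token}"},
        json=payload,
        timeout=60.0,
    )
    response.raise_for_status()
    return response.json()

Because this is a plain HTTP call, new request and response fields become usable as soon as Google ships them in the API, without waiting for an SDK release, which is the motivation stated in the commit message.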


@@ -12,7 +12,12 @@ from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_result,
convert_to_gemini_tool_call_invoke,
)
-from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
+from litellm.types.files import (
+    get_file_mime_type_for_file_type,
+    get_file_type_from_extension,
+    is_gemini_1_5_accepted_file_type,
+    is_video_file_type,
+)
class VertexAIError(Exception):
@@ -301,15 +306,15 @@ def _process_gemini_image(image_url: str) -> PartType:
# GCS URIs
if "gs://" in image_url:
# Figure out file type
-extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
-extension = extension_with_dot[1:] # Ex: "png"
+extension_with_dot = os.path.splitext(image_url)[-1]  # Ex: ".png"
+extension = extension_with_dot[1:]  # Ex: "png"
file_type = get_file_type_from_extension(extension)
# Validate the file type is supported by Gemini
if not is_gemini_1_5_accepted_file_type(file_type):
raise Exception(f"File type not supported by gemini - {file_type}")
mime_type = get_file_mime_type_for_file_type(file_type)
file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
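
The GCS branch above infers the file's type from its extension before building a FileDataType part. As a rough illustration of that idea using only the standard library (this is not litellm's helper code from litellm.types.files), extension-based MIME detection looks like this:

# Sketch: guess a MIME type for a GCS URI from its extension (standard library only).
import mimetypes
import os

def guess_gcs_mime_type(gcs_uri: str) -> str:
    extension = os.path.splitext(gcs_uri)[-1]  # e.g. ".png"
    mime_type, _ = mimetypes.guess_type("file" + extension)
    if mime_type is None:
        raise ValueError(f"Could not determine MIME type for {gcs_uri}")
    return mime_type

print(guess_gcs_mime_type("gs://my-bucket/report.pdf"))  # application/pdf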
@@ -320,7 +325,7 @@ def _process_gemini_image(image_url: str) -> PartType:
image = _load_image_from_url(image_url)
_blob = BlobType(data=image.data, mime_type=image._mime_type)
return PartType(inline_data=_blob)
# Base64 encoding
elif "base64" in image_url:
import base64, re
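
The base64 branch that starts at the end of this hunk handles data URLs. In general terms (a sketch, not litellm's exact parsing), pulling the MIME type and raw bytes out of a base64 data URL looks like this:

# Sketch: split a base64 data URL into (mime_type, raw_bytes); illustrative only.
import base64
import re

def parse_data_url(data_url: str) -> tuple[str, bytes]:
    match = re.match(r"data:(?P<mime>[\w/+.-]+);base64,(?P<data>.+)", data_url)
    if match is None:
        raise ValueError("Not a base64-encoded data URL")
    return match.group("mime"), base64.b64decode(match.group("data"))

mime, raw = parse_data_url("data:image/png;base64,iVBORw0KGgo=")
print(mime)  # image/png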
@@ -611,7 +616,7 @@ def completion(
llm_model = None
# NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
-if acompletion == True:
+if acompletion is True:
data = {
"llm_model": llm_model,
"mode": mode,
@@ -643,7 +648,7 @@ def completion(
tools = optional_params.pop("tools", None)
content = _gemini_convert_messages_with_history(messages=messages)
stream = optional_params.pop("stream", False)
-if stream == True:
+if stream is True:
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(
input=prompt,