mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(vertex_ai.py): support passing in result of tool call to vertex
Fixes https://github.com/BerriAI/litellm/issues/3709
This commit is contained in:
parent
5d3fe52a08
commit
a2c66ed4fb
5 changed files with 485 additions and 46 deletions
|
@ -3,10 +3,15 @@ import json
|
|||
from enum import Enum
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, Union, List
|
||||
from typing import Callable, Optional, Union, List, Literal
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
|
||||
import litellm, uuid
|
||||
import httpx, inspect # type: ignore
|
||||
from litellm.types.llms.vertex_ai import *
|
||||
from litellm.llms.prompt_templates.factory import (
|
||||
convert_to_gemini_tool_call_result,
|
||||
convert_to_gemini_tool_call_invoke,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
|
@ -283,6 +288,129 @@ def _load_image_from_url(image_url: str):
|
|||
return Image.from_bytes(data=image_bytes)
|
||||
|
||||
|
||||
def _convert_gemini_role(role: str) -> Literal["user", "model"]:
|
||||
if role == "user":
|
||||
return "user"
|
||||
else:
|
||||
return "model"
|
||||
|
||||
|
||||
def _process_gemini_image(image_url: str):
|
||||
try:
|
||||
import vertexai
|
||||
except:
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
|
||||
)
|
||||
from vertexai.preview.generative_models import Part
|
||||
|
||||
if "gs://" in image_url:
|
||||
# Case 1: Images with Cloud Storage URIs
|
||||
# The supported MIME types for images include image/png and image/jpeg.
|
||||
part_mime = "image/png" if "png" in image_url else "image/jpeg"
|
||||
google_clooud_part = Part.from_uri(image_url, mime_type=part_mime)
|
||||
return google_clooud_part
|
||||
elif "https:/" in image_url:
|
||||
# Case 2: Images with direct links
|
||||
image = _load_image_from_url(image_url)
|
||||
return image
|
||||
elif ".mp4" in image_url and "gs://" in image_url:
|
||||
# Case 3: Videos with Cloud Storage URIs
|
||||
part_mime = "video/mp4"
|
||||
google_clooud_part = Part.from_uri(image_url, mime_type=part_mime)
|
||||
return google_clooud_part
|
||||
elif "base64" in image_url:
|
||||
# Case 4: Images with base64 encoding
|
||||
import base64, re
|
||||
|
||||
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
|
||||
image_metadata, img_without_base_64 = image_url.split(",")
|
||||
|
||||
# read mime_type from img_without_base_64=data:image/jpeg;base64
|
||||
# Extract MIME type using regular expression
|
||||
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
|
||||
|
||||
if mime_type_match:
|
||||
mime_type = mime_type_match.group(1)
|
||||
else:
|
||||
mime_type = "image/jpeg"
|
||||
decoded_img = base64.b64decode(img_without_base_64)
|
||||
processed_image = Part.from_data(data=decoded_img, mime_type=mime_type)
|
||||
return processed_image
|
||||
|
||||
|
||||
def _gemini_convert_messages_text(messages: list) -> List[ContentType]:
|
||||
"""
|
||||
Converts given messages from OpenAI format to Gemini format
|
||||
|
||||
- Parts must be iterable
|
||||
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
|
||||
- Please ensure that function response turn comes immediately after a function call turn
|
||||
"""
|
||||
user_message_types = {"user", "system"}
|
||||
contents: List[ContentType] = []
|
||||
|
||||
msg_i = 0
|
||||
while msg_i < len(messages):
|
||||
user_content: List[PartType] = []
|
||||
init_msg_i = msg_i
|
||||
## MERGE CONSECUTIVE USER CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
|
||||
if isinstance(messages[msg_i]["content"], list):
|
||||
_parts: List[PartType] = []
|
||||
for element in messages[msg_i]["content"]:
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text":
|
||||
_part = PartType(text=element["text"])
|
||||
_parts.append(_part)
|
||||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
_part = _process_gemini_image(image_url=image_url)
|
||||
_parts.append(_part) # type: ignore
|
||||
user_content.extend(_parts)
|
||||
else:
|
||||
_part = PartType(text=messages[msg_i]["content"])
|
||||
user_content.append(_part)
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if user_content:
|
||||
contents.append(ContentType(role="user", parts=user_content))
|
||||
assistant_content = []
|
||||
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
|
||||
assistant_text = (
|
||||
messages[msg_i].get("content") or ""
|
||||
) # either string or none
|
||||
if assistant_text:
|
||||
assistant_content.append(PartType(text=assistant_text))
|
||||
if messages[msg_i].get(
|
||||
"tool_calls", []
|
||||
): # support assistant tool invoke convertion
|
||||
assistant_content.extend(
|
||||
convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
|
||||
)
|
||||
msg_i += 1
|
||||
|
||||
if assistant_content:
|
||||
contents.append(ContentType(role="model", parts=assistant_content))
|
||||
|
||||
## APPEND TOOL CALL MESSAGES ##
|
||||
if messages[msg_i]["role"] == "tool":
|
||||
_part = convert_to_gemini_tool_call_result(messages[msg_i])
|
||||
contents.append(ContentType(parts=[_part])) # type: ignore
|
||||
msg_i += 1
|
||||
if msg_i == init_msg_i: # prevent infinite loops
|
||||
raise Exception(
|
||||
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
|
||||
messages[msg_i]
|
||||
)
|
||||
)
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
def _gemini_vision_convert_messages(messages: list):
|
||||
"""
|
||||
Converts given messages for GPT-4 Vision to Gemini format.
|
||||
|
@ -574,11 +702,9 @@ def completion(
|
|||
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
tools = optional_params.pop("tools", None)
|
||||
prompt, images = _gemini_vision_convert_messages(messages=messages)
|
||||
content = [prompt] + images
|
||||
content = _gemini_convert_messages_text(messages=messages)
|
||||
stream = optional_params.pop("stream", False)
|
||||
if stream == True:
|
||||
|
||||
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
|
@ -590,7 +716,7 @@ def completion(
|
|||
)
|
||||
|
||||
model_response = llm_model.generate_content(
|
||||
contents=content,
|
||||
contents={"content": content},
|
||||
generation_config=optional_params,
|
||||
safety_settings=safety_settings,
|
||||
stream=True,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue