mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(vertex_ai.py): passing all tests on 'test_amazing_vertex_completion.py
This commit is contained in:
parent
a2c66ed4fb
commit
f9ab72841a
3 changed files with 4246 additions and 57 deletions
|
@ -295,49 +295,45 @@ def _convert_gemini_role(role: str) -> Literal["user", "model"]:
|
|||
return "model"
|
||||
|
||||
|
||||
def _process_gemini_image(image_url: str):
|
||||
def _process_gemini_image(image_url: str) -> PartType:
|
||||
try:
|
||||
import vertexai
|
||||
except:
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
|
||||
)
|
||||
from vertexai.preview.generative_models import Part
|
||||
if "gs://" in image_url:
|
||||
# Case 1: Images with Cloud Storage URIs
|
||||
# The supported MIME types for images include image/png and image/jpeg.
|
||||
part_mime = "image/png" if "png" in image_url else "image/jpeg"
|
||||
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
|
||||
return PartType(file_data=_file_data)
|
||||
elif "https:/" in image_url:
|
||||
# Case 2: Images with direct links
|
||||
image = _load_image_from_url(image_url)
|
||||
_blob = BlobType(data=image.data, mime_type=image._mime_type)
|
||||
return PartType(inline_data=_blob)
|
||||
elif ".mp4" in image_url and "gs://" in image_url:
|
||||
# Case 3: Videos with Cloud Storage URIs
|
||||
part_mime = "video/mp4"
|
||||
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
|
||||
return PartType(file_data=_file_data)
|
||||
elif "base64" in image_url:
|
||||
# Case 4: Images with base64 encoding
|
||||
import base64, re
|
||||
|
||||
if "gs://" in image_url:
|
||||
# Case 1: Images with Cloud Storage URIs
|
||||
# The supported MIME types for images include image/png and image/jpeg.
|
||||
part_mime = "image/png" if "png" in image_url else "image/jpeg"
|
||||
google_clooud_part = Part.from_uri(image_url, mime_type=part_mime)
|
||||
return google_clooud_part
|
||||
elif "https:/" in image_url:
|
||||
# Case 2: Images with direct links
|
||||
image = _load_image_from_url(image_url)
|
||||
return image
|
||||
elif ".mp4" in image_url and "gs://" in image_url:
|
||||
# Case 3: Videos with Cloud Storage URIs
|
||||
part_mime = "video/mp4"
|
||||
google_clooud_part = Part.from_uri(image_url, mime_type=part_mime)
|
||||
return google_clooud_part
|
||||
elif "base64" in image_url:
|
||||
# Case 4: Images with base64 encoding
|
||||
import base64, re
|
||||
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
|
||||
image_metadata, img_without_base_64 = image_url.split(",")
|
||||
|
||||
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
|
||||
image_metadata, img_without_base_64 = image_url.split(",")
|
||||
# read mime_type from img_without_base_64=data:image/jpeg;base64
|
||||
# Extract MIME type using regular expression
|
||||
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
|
||||
|
||||
# read mime_type from img_without_base_64=data:image/jpeg;base64
|
||||
# Extract MIME type using regular expression
|
||||
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
|
||||
|
||||
if mime_type_match:
|
||||
mime_type = mime_type_match.group(1)
|
||||
else:
|
||||
mime_type = "image/jpeg"
|
||||
decoded_img = base64.b64decode(img_without_base_64)
|
||||
processed_image = Part.from_data(data=decoded_img, mime_type=mime_type)
|
||||
return processed_image
|
||||
if mime_type_match:
|
||||
mime_type = mime_type_match.group(1)
|
||||
else:
|
||||
mime_type = "image/jpeg"
|
||||
decoded_img = base64.b64decode(img_without_base_64)
|
||||
_blob = BlobType(data=decoded_img, mime_type=mime_type)
|
||||
return PartType(inline_data=_blob)
|
||||
raise Exception("Invalid image received - {}".format(image_url))
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def _gemini_convert_messages_text(messages: list) -> List[ContentType]:
|
||||
|
@ -397,7 +393,7 @@ def _gemini_convert_messages_text(messages: list) -> List[ContentType]:
|
|||
contents.append(ContentType(role="model", parts=assistant_content))
|
||||
|
||||
## APPEND TOOL CALL MESSAGES ##
|
||||
if messages[msg_i]["role"] == "tool":
|
||||
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
|
||||
_part = convert_to_gemini_tool_call_result(messages[msg_i])
|
||||
contents.append(ContentType(parts=[_part])) # type: ignore
|
||||
msg_i += 1
|
||||
|
@ -524,10 +520,10 @@ def completion(
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
vertex_credentials=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
acompletion: bool = False,
|
||||
|
@ -715,15 +711,15 @@ def completion(
|
|||
},
|
||||
)
|
||||
|
||||
model_response = llm_model.generate_content(
|
||||
contents={"content": content},
|
||||
_model_response = llm_model.generate_content(
|
||||
contents=content,
|
||||
generation_config=optional_params,
|
||||
safety_settings=safety_settings,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
return model_response
|
||||
return _model_response
|
||||
|
||||
request_str += f"response = llm_model.generate_content({content})\n"
|
||||
## LOGGING
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue