fix(vertex_ai.py): handle nested content dictionary for assistant message

openai allows the assistant content message to also be a list of dictionaries, handle that
This commit is contained in:
Krrish Dholakia 2024-05-24 22:47:34 -07:00
parent 799c63bf8e
commit 281906ff33
2 changed files with 44 additions and 6 deletions

View file

@ -376,17 +376,31 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
assistant_content = [] assistant_content = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ## ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
assistant_text = ( if isinstance(messages[msg_i]["content"], list):
messages[msg_i].get("content") or "" _parts = []
) # either string or none for element in messages[msg_i]["content"]:
if assistant_text: if isinstance(element, dict):
assistant_content.append(PartType(text=assistant_text)) if element["type"] == "text":
if messages[msg_i].get( _part = PartType(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
_part = _process_gemini_image(image_url=image_url)
_parts.append(_part) # type: ignore
assistant_content.extend(_parts)
elif messages[msg_i].get(
"tool_calls", [] "tool_calls", []
): # support assistant tool invoke convertion ): # support assistant tool invoke convertion
assistant_content.extend( assistant_content.extend(
convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"]) convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
) )
else:
assistant_text = (
messages[msg_i].get("content") or ""
) # either string or none
if assistant_text:
assistant_content.append(PartType(text=assistant_text))
msg_i += 1 msg_i += 1
if assistant_content: if assistant_content:

View file

@ -975,3 +975,27 @@ def test_prompt_factory():
translated_messages = _gemini_convert_messages_with_history(messages=messages) translated_messages = _gemini_convert_messages_with_history(messages=messages)
print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages") print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")
def test_prompt_factory_nested():
messages = [
{"role": "user", "content": [{"type": "text", "text": "hi"}]},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Hi! 👋 \n\nHow can I help you today? 😊 \n"}
],
},
{"role": "user", "content": [{"type": "text", "text": "hi 2nd time"}]},
]
translated_messages = _gemini_convert_messages_with_history(messages=messages)
print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")
for message in translated_messages:
assert len(message["parts"]) == 1
assert "text" in message["parts"][0], "Missing 'text' from 'parts'"
assert isinstance(
message["parts"][0]["text"], str
), "'text' value not a string."