matching openai tool result spec

This commit is contained in:
andrewmjc 2024-07-02 16:57:13 -06:00
parent c4e11e03d7
commit e07b110b47
2 changed files with 17 additions and 4 deletions

View file

@ -1022,16 +1022,17 @@ def convert_to_gemini_tool_call_invoke(
def convert_to_gemini_tool_call_result(
message: dict,
last_message_with_tool_calls: dict|None,
) -> litellm.types.llms.vertex_ai.PartType:
"""
OpenAI message with a tool result looks like:
{
"tool_call_id": "tool_1",
"role": "tool",
"name": "get_current_weather",
"role": "tool",
"content": "function result goes here",
},
# NOTE: Function messages have been deprecated
OpenAI message with a function call result looks like:
{
"role": "function",
@ -1040,7 +1041,16 @@ def convert_to_gemini_tool_call_result(
}
"""
content = message.get("content", "")
name = message.get("name", "")
name = ""
# Recover name from last message with tool calls
if last_message_with_tool_calls:
tools = last_message_with_tool_calls.get("tool_calls", [])
msg_tool_call_id = message.get("tool_call_id", None)
for tool in tools:
prev_tool_call_id = tool.get("id", None)
if msg_tool_call_id and prev_tool_call_id and msg_tool_call_id == prev_tool_call_id:
name = tool.get("function", {}).get("name", "")
# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template

View file

@ -328,6 +328,8 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
user_message_types = {"user", "system"}
contents: List[ContentType] = []
last_message_with_tool_calls = None
msg_i = 0
try:
while msg_i < len(messages):
@ -383,6 +385,7 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
messages[msg_i]["tool_calls"]
)
)
last_message_with_tool_calls = messages[msg_i]
else:
assistant_text = (
messages[msg_i].get("content") or ""
@ -397,7 +400,7 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
_part = convert_to_gemini_tool_call_result(messages[msg_i])
_part = convert_to_gemini_tool_call_result(messages[msg_i], last_message_with_tool_calls)
contents.append(ContentType(parts=[_part])) # type: ignore
msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops