matching openai tool result spec

This commit is contained in:
andrewmjc 2024-07-02 16:57:13 -06:00
parent c4e11e03d7
commit e07b110b47
2 changed files with 17 additions and 4 deletions

View file

@ -1022,16 +1022,17 @@ def convert_to_gemini_tool_call_invoke(
def convert_to_gemini_tool_call_result( def convert_to_gemini_tool_call_result(
message: dict, message: dict,
last_message_with_tool_calls: dict|None,
) -> litellm.types.llms.vertex_ai.PartType: ) -> litellm.types.llms.vertex_ai.PartType:
""" """
OpenAI message with a tool result looks like: OpenAI message with a tool result looks like:
{ {
"tool_call_id": "tool_1", "tool_call_id": "tool_1",
"role": "tool", "role": "tool",
"name": "get_current_weather",
"content": "function result goes here", "content": "function result goes here",
}, },
# NOTE: Function messages have been deprecated
OpenAI message with a function call result looks like: OpenAI message with a function call result looks like:
{ {
"role": "function", "role": "function",
@ -1040,7 +1041,16 @@ def convert_to_gemini_tool_call_result(
} }
""" """
content = message.get("content", "") content = message.get("content", "")
name = message.get("name", "") name = ""
# Recover name from last message with tool calls
if last_message_with_tool_calls:
tools = last_message_with_tool_calls.get("tool_calls", [])
msg_tool_call_id = message.get("tool_call_id", None)
for tool in tools:
prev_tool_call_id = tool.get("id", None)
if msg_tool_call_id and prev_tool_call_id and msg_tool_call_id == prev_tool_call_id:
name = tool.get("function", {}).get("name", "")
# We can't determine from openai message format whether it's a successful or # We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template # error call result so default to the successful result template

View file

@ -328,6 +328,8 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
user_message_types = {"user", "system"} user_message_types = {"user", "system"}
contents: List[ContentType] = [] contents: List[ContentType] = []
last_message_with_tool_calls = None
msg_i = 0 msg_i = 0
try: try:
while msg_i < len(messages): while msg_i < len(messages):
@ -383,6 +385,7 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
messages[msg_i]["tool_calls"] messages[msg_i]["tool_calls"]
) )
) )
last_message_with_tool_calls = messages[msg_i]
else: else:
assistant_text = ( assistant_text = (
messages[msg_i].get("content") or "" messages[msg_i].get("content") or ""
@ -397,7 +400,7 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
## APPEND TOOL CALL MESSAGES ## ## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and messages[msg_i]["role"] == "tool": if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
_part = convert_to_gemini_tool_call_result(messages[msg_i]) _part = convert_to_gemini_tool_call_result(messages[msg_i], last_message_with_tool_calls)
contents.append(ContentType(parts=[_part])) # type: ignore contents.append(ContentType(parts=[_part])) # type: ignore
msg_i += 1 msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops if msg_i == init_msg_i: # prevent infinite loops