diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 7ec21e8bb..c9e691c00 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -2195,7 +2195,7 @@ def _convert_to_bedrock_tool_call_invoke(
 
 def _convert_to_bedrock_tool_call_result(
     message: dict,
-) -> BedrockMessageBlock:
+) -> BedrockContentBlock:
     """
     OpenAI message with a tool result looks like:
     {
@@ -2247,7 +2247,7 @@ def _convert_to_bedrock_tool_call_result(
     )
     content_block = BedrockContentBlock(toolResult=tool_result)
 
-    return BedrockMessageBlock(role="user", content=[content_block])
+    return content_block
 
 
 def _bedrock_converse_messages_pt(
@@ -2289,6 +2289,12 @@ def _bedrock_converse_messages_pt(
 
             msg_i += 1
 
+        ## MERGE CONSECUTIVE TOOL CALL MESSAGES ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+            tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
+
+            user_content.append(tool_call_result)
+            msg_i += 1
         if user_content:
             contents.append(BedrockMessageBlock(role="user", content=user_content))
         assistant_content: List[BedrockContentBlock] = []
@@ -2332,11 +2338,6 @@ def _bedrock_converse_messages_pt(
                 BedrockMessageBlock(role="assistant", content=assistant_content)
             )
 
-        ## APPEND TOOL CALL MESSAGES ##
-        if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
-            tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
-            contents.append(tool_call_result)
-            msg_i += 1
         if msg_i == init_msg_i:  # prevent infinite loops
             raise litellm.BadRequestError(
                 message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
diff --git a/litellm/tests/test_function_calling.py b/litellm/tests/test_function_calling.py
index 6e4e9d3e8..5f97dbf87 100644
--- a/litellm/tests/test_function_calling.py
+++ b/litellm/tests/test_function_calling.py
@@ -1,18 +1,20 @@
-import sys, os
+import os
+import sys
 import traceback
+
 from dotenv import load_dotenv
 
 load_dotenv()
-import os, io
+import io
+import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest
+
 import litellm
-from litellm import embedding, completion, completion_cost, Timeout
-from litellm import RateLimitError
-import pytest
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
 
 litellm.num_retries = 0
 litellm.cache = None
@@ -41,7 +43,14 @@ def get_current_weather(location, unit="fahrenheit"):
 
 # In production, this could be your backend API or an external API
 @pytest.mark.parametrize(
-    "model", ["gpt-3.5-turbo-1106", "mistral/mistral-large-latest"]
+    "model",
+    [
+        "gpt-3.5-turbo-1106",
+        "mistral/mistral-large-latest",
+        "claude-3-haiku-20240307",
+        "gemini/gemini-1.5-pro",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+    ],
 )
 def test_parallel_function_call(model):
     try:
@@ -124,7 +133,12 @@ def test_parallel_function_call(model):
                 )  # extend conversation with function response
             print(f"messages: {messages}")
             second_response = litellm.completion(
-                model=model, messages=messages, temperature=0.2, seed=22
+                model=model,
+                messages=messages,
+                temperature=0.2,
+                seed=22,
+                tools=tools,
+                drop_params=True,
             )  # get a new response from the model where it can see the function response
             print("second response\n", second_response)
     except Exception as e:
diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py
index 93e92a792..81339e831 100644
--- a/litellm/tests/test_prompt_factory.py
+++ b/litellm/tests/test_prompt_factory.py
@@ -313,3 +313,78 @@ def test_anthropic_cache_controls_pt():
         assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
 
     print("translated_messages: ", translated_messages)
+
+
+@pytest.mark.parametrize("provider", ["bedrock", "anthropic"])
+def test_bedrock_parallel_tool_calling_pt(provider):
+    """
+    Make sure parallel tool call blocks are merged correctly - https://github.com/BerriAI/litellm/issues/5277
+    """
+    from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
+    from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message
+
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
+        },
+        Message(
+            content="Here are the current weather conditions for San Francisco, Tokyo, and Paris:",
+            role="assistant",
+            tool_calls=[
+                ChatCompletionMessageToolCall(
+                    index=1,
+                    function=Function(
+                        arguments='{"city": "New York"}',
+                        name="get_current_weather",
+                    ),
+                    id="tooluse_XcqEBfm8R-2YVaPhDUHsPQ",
+                    type="function",
+                ),
+                ChatCompletionMessageToolCall(
+                    index=2,
+                    function=Function(
+                        arguments='{"city": "London"}',
+                        name="get_current_weather",
+                    ),
+                    id="tooluse_VB9nk7UGRniVzGcaj6xrAQ",
+                    type="function",
+                ),
+            ],
+            function_call=None,
+        ),
+        {
+            "tool_call_id": "tooluse_XcqEBfm8R-2YVaPhDUHsPQ",
+            "role": "tool",
+            "name": "get_current_weather",
+            "content": "25 degrees celsius.",
+        },
+        {
+            "tool_call_id": "tooluse_VB9nk7UGRniVzGcaj6xrAQ",
+            "role": "tool",
+            "name": "get_current_weather",
+            "content": "28 degrees celsius.",
+        },
+    ]
+
+    if provider == "bedrock":
+        translated_messages = _bedrock_converse_messages_pt(
+            messages=messages,
+            model="anthropic.claude-3-sonnet-20240229-v1:0",
+            llm_provider="bedrock",
+        )
+    else:
+        translated_messages = anthropic_messages_pt(
+            messages=messages,
+            model="claude-3-sonnet-20240229-v1:0",
+            llm_provider=provider,
+        )
+    print(translated_messages)
+
+    number_of_messages = len(translated_messages)
+
+    # assert last 2 messages are not the same role
+    assert (
+        translated_messages[number_of_messages - 1]["role"]
+        != translated_messages[number_of_messages - 2]["role"]
+    )