diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 38478d931..305400f4a 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -145,6 +145,12 @@ def mistral_api_pt(messages):
         elif isinstance(m["content"], str):
             texts = m["content"]
         new_m = {"role": m["role"], "content": texts}
+
+        if new_m["role"] == "tool" and m.get("name"):
+            new_m["name"] = m["name"]
+        if m.get("tool_calls"):
+            new_m["tool_calls"] = m["tool_calls"]
+
         new_messages.append(new_m)
     return new_messages
 
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 09053cf17..61ea12aa1 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -484,6 +484,76 @@ def test_completion_mistral_api():
         pytest.fail(f"Error occurred: {e}")
 
 
+def test_completion_mistral_api_mistral_large_function_call():
+    litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in Boston today in Fahrenheit?",
+        }
+    ]
+    try:
+        # test without max tokens
+        response = completion(
+            model="mistral/mistral-large-latest",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",
+        )
+        # Add any assertions here to check response args
+        print(response)
+        assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
+        assert isinstance(
+            response.choices[0].message.tool_calls[0].function.arguments, str
+        )
+
+        messages.append(
+            response.choices[0].message.model_dump()
+        )  # Add assistant tool invokes
+        tool_result = (
+            '{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
+        )
+        # Add user-submitted tool results in the OpenAI format
+        messages.append(
+            {
+                "tool_call_id": response.choices[0].message.tool_calls[0].id,
+                "role": "tool",
+                "name": response.choices[0].message.tool_calls[0].function.name,
+                "content": tool_result,
+            }
+        )
+        # In the second response, Mistral should deduce the answer from the tool results
+        second_response = completion(
+            model="mistral/mistral-large-latest",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",
+        )
+        print(second_response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.skip(
     reason="Since we already test mistral/mistral-tiny in test_completion_mistral_api. This is only for locally verifying azure mistral works"
 )
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index ea2f3fcb7..914eac7e0 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -589,6 +589,64 @@ def test_completion_mistral_api_stream():
         pytest.fail(f"Error occurred: {e}")
 
 
+def test_completion_mistral_api_mistral_large_function_call_with_streaming():
+    litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + messages = [ + { + "role": "user", + "content": "What's the weather like in Boston today in fahrenheit?", + } + ] + try: + # test without max tokens + response = completion( + model="mistral/mistral-large-latest", + messages=messages, + tools=tools, + tool_choice="auto", + stream=True, + ) + idx = 0 + for chunk in response: + print(f"chunk in response: {chunk}") + if idx == 0: + assert ( + chunk.choices[0].delta.tool_calls[0].function.arguments is not None + ) + assert isinstance( + chunk.choices[0].delta.tool_calls[0].function.arguments, str + ) + validate_first_streaming_function_calling_chunk(chunk=chunk) + elif idx == 1 and chunk.choices[0].finish_reason is None: + validate_second_streaming_function_calling_chunk(chunk=chunk) + elif chunk.choices[0].finish_reason is not None: # last chunk + validate_final_streaming_function_calling_chunk(chunk=chunk) + idx += 1 + # raise Exception("it worked!") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + # test_completion_mistral_api_stream()