Merge pull request #3124 from elisalimli/bugfix/add-missing-tool-calls-mistral-messages

Add missing tool_calls and name to messages
Krish Dholakia 2024-04-23 17:25:12 -07:00 committed by GitHub
commit 8d2e411df6
3 changed files with 134 additions and 0 deletions


@@ -145,6 +145,12 @@ def mistral_api_pt(messages):
        elif isinstance(m["content"], str):
            texts = m["content"]

        new_m = {"role": m["role"], "content": texts}

        if new_m["role"] == "tool" and m.get("name"):
            new_m["name"] = m["name"]
        if m.get("tool_calls"):
            new_m["tool_calls"] = m["tool_calls"]

        new_messages.append(new_m)
    return new_messages
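
For context, a minimal sketch (not part of the diff) of the two message shapes this patch stops dropping, assuming the OpenAI chat format that litellm accepts: an assistant message carrying tool_calls, and a tool-result message carrying name. Before the fix, mistral_api_pt rebuilt every message as a bare role/content dict, so both keys were lost.

# Sketch of the message shapes preserved by the patch (assumed OpenAI
# chat format; the id and values below are hypothetical examples).
assistant_msg = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {
            "id": "call_abc123",  # hypothetical tool-call id
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": '{"location": "Boston, MA"}',
            },
        }
    ],
}
tool_msg = {
    "role": "tool",
    "name": "get_current_weather",  # previously dropped by mistral_api_pt
    "content": '{"temperature": "72", "unit": "fahrenheit"}',
}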


@@ -484,6 +484,76 @@ def test_completion_mistral_api():
        pytest.fail(f"Error occurred: {e}")


def test_completion_mistral_api_mistral_large_function_call():
    litellm.set_verbose = True
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    messages = [
        {
            "role": "user",
            "content": "What's the weather like in Boston today in Fahrenheit?",
        }
    ]
    try:
        # test without max tokens
        response = completion(
            model="mistral/mistral-large-latest",
            messages=messages,
            tools=tools,
            tool_choice="auto",
        )
        # Add any assertions here to check response args
        print(response)
        assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
        assert isinstance(
            response.choices[0].message.tool_calls[0].function.arguments, str
        )

        messages.append(
            response.choices[0].message.model_dump()
        )  # Add assistant tool invokes
        tool_result = (
            '{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
        )
        # Add user submitted tool results in the OpenAI format
        messages.append(
            {
                "tool_call_id": response.choices[0].message.tool_calls[0].id,
                "role": "tool",
                "name": response.choices[0].message.tool_calls[0].function.name,
                "content": tool_result,
            }
        )
        # In the second response, Mistral should deduce answer from tool results
        second_response = completion(
            model="mistral/mistral-large-latest",
            messages=messages,
            tools=tools,
            tool_choice="auto",
        )
        print(second_response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
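
As an aside, the test hardcodes tool_result; a minimal sketch, under the assumption of a local stub, of how the result would be produced in practice by dispatching on the function name the model chose (get_current_weather and run_tool here are hypothetical helpers, not part of the test suite):

import json

def get_current_weather(location, unit="fahrenheit"):
    # Hypothetical stub; a real implementation would query a weather API.
    return {"location": location, "temperature": "72", "unit": unit}

def run_tool(tool_call):
    # Dispatch on the function name the model chose and return a JSON
    # string suitable for the "content" of a tool-result message.
    args = json.loads(tool_call.function.arguments)
    if tool_call.function.name == "get_current_weather":
        return json.dumps(get_current_weather(**args))
    raise ValueError(f"Unknown tool: {tool_call.function.name}")

With that in place, tool_result = run_tool(response.choices[0].message.tool_calls[0]) would replace the hardcoded string above.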

@pytest.mark.skip(
    reason="We already test mistral/mistral-tiny in test_completion_mistral_api; this is only for locally verifying that Azure Mistral works."
)


@@ -589,6 +589,64 @@ def test_completion_mistral_api_stream():
        pytest.fail(f"Error occurred: {e}")


def test_completion_mistral_api_mistral_large_function_call_with_streaming():
    litellm.set_verbose = True
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    messages = [
        {
            "role": "user",
            "content": "What's the weather like in Boston today in fahrenheit?",
        }
    ]
    try:
        # test without max tokens
        response = completion(
            model="mistral/mistral-large-latest",
            messages=messages,
            tools=tools,
            tool_choice="auto",
            stream=True,
        )
        idx = 0
        for chunk in response:
            print(f"chunk in response: {chunk}")
            if idx == 0:
                assert (
                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
                )
                assert isinstance(
                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
                )
                validate_first_streaming_function_calling_chunk(chunk=chunk)
            elif idx == 1 and chunk.choices[0].finish_reason is None:
                validate_second_streaming_function_calling_chunk(chunk=chunk)
            elif chunk.choices[0].finish_reason is not None:  # last chunk
                validate_final_streaming_function_calling_chunk(chunk=chunk)
            idx += 1
        # raise Exception("it worked!")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_completion_mistral_api_stream()
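
For completeness, a minimal sketch (an assumption, not from the diff) of how a caller might reassemble the argument fragments that the streaming test above validates chunk by chunk:

from litellm import completion

def collect_tool_arguments(model, messages, tools):
    # Concatenate streamed tool-call argument fragments into one JSON string.
    arguments = ""
    response = completion(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice="auto",
        stream=True,
    )
    for chunk in response:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            arguments += delta.tool_calls[0].function.arguments or ""
    return arguments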