forked from phoenix/litellm-mirror
Merge pull request #3124 from elisalimli/bugfix/add-missing-tool-calls-mistral-messages
Add missing tool_calls and name to messages
This commit is contained in:
commit
8d2e411df6
3 changed files with 134 additions and 0 deletions
|
@ -145,6 +145,12 @@ def mistral_api_pt(messages):
|
|||
elif isinstance(m["content"], str):
|
||||
texts = m["content"]
|
||||
new_m = {"role": m["role"], "content": texts}
|
||||
|
||||
if new_m["role"] == "tool" and m.get("name"):
|
||||
new_m["name"] = m["name"]
|
||||
if m.get("tool_calls"):
|
||||
new_m["tool_calls"] = m["tool_calls"]
|
||||
|
||||
new_messages.append(new_m)
|
||||
return new_messages
|
||||
|
||||
|
|
|
@ -484,6 +484,76 @@ def test_completion_mistral_api():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_mistral_api_mistral_large_function_call():
|
||||
litellm.set_verbose = True
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
}
|
||||
]
|
||||
try:
|
||||
# test without max tokens
|
||||
response = completion(
|
||||
model="mistral/mistral-large-latest",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
# Add any assertions, here to check response args
|
||||
print(response)
|
||||
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
|
||||
assert isinstance(
|
||||
response.choices[0].message.tool_calls[0].function.arguments, str
|
||||
)
|
||||
|
||||
messages.append(
|
||||
response.choices[0].message.model_dump()
|
||||
) # Add assistant tool invokes
|
||||
tool_result = (
|
||||
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
|
||||
)
|
||||
# Add user submitted tool results in the OpenAI format
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": response.choices[0].message.tool_calls[0].id,
|
||||
"role": "tool",
|
||||
"name": response.choices[0].message.tool_calls[0].function.name,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
# In the second response, Mistral should deduce answer from tool results
|
||||
second_response = completion(
|
||||
model="mistral/mistral-large-latest",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
print(second_response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Since we already test mistral/mistral-tiny in test_completion_mistral_api. This is only for locally verifying azure mistral works"
|
||||
)
|
||||
|
|
|
@ -589,6 +589,64 @@ def test_completion_mistral_api_stream():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_mistral_api_mistral_large_function_call_with_streaming():
|
||||
litellm.set_verbose = True
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in fahrenheit?",
|
||||
}
|
||||
]
|
||||
try:
|
||||
# test without max tokens
|
||||
response = completion(
|
||||
model="mistral/mistral-large-latest",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
stream=True,
|
||||
)
|
||||
idx = 0
|
||||
for chunk in response:
|
||||
print(f"chunk in response: {chunk}")
|
||||
if idx == 0:
|
||||
assert (
|
||||
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
|
||||
)
|
||||
assert isinstance(
|
||||
chunk.choices[0].delta.tool_calls[0].function.arguments, str
|
||||
)
|
||||
validate_first_streaming_function_calling_chunk(chunk=chunk)
|
||||
elif idx == 1 and chunk.choices[0].finish_reason is None:
|
||||
validate_second_streaming_function_calling_chunk(chunk=chunk)
|
||||
elif chunk.choices[0].finish_reason is not None: # last chunk
|
||||
validate_final_streaming_function_calling_chunk(chunk=chunk)
|
||||
idx += 1
|
||||
# raise Exception("it worked!")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_completion_mistral_api_stream()
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue