mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
minor fix
This commit is contained in:
parent
31453f3f79
commit
541d0c6f1a
2 changed files with 24 additions and 20 deletions
|
@ -496,48 +496,53 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_case",
|
"test_case",
|
||||||
[
|
[
|
||||||
|
# Tests if the model can handle simple messages like "Hi" or
|
||||||
|
# a message unrelated to one of the tool calls
|
||||||
"inference:chat_completion:multi_turn_tool_calling_01",
|
"inference:chat_completion:multi_turn_tool_calling_01",
|
||||||
|
# Tests if the model can do full tool call with responses correctly
|
||||||
"inference:chat_completion:multi_turn_tool_calling_02",
|
"inference:chat_completion:multi_turn_tool_calling_02",
|
||||||
|
# Tests if model can generate multiple params and
|
||||||
|
# read outputs correctly
|
||||||
"inference:chat_completion:multi_turn_tool_calling_03",
|
"inference:chat_completion:multi_turn_tool_calling_03",
|
||||||
|
# Tests if model can do different tool calls in a seqeunce
|
||||||
|
# and use the information between appropriately
|
||||||
"inference:chat_completion:multi_turn_tool_calling_04",
|
"inference:chat_completion:multi_turn_tool_calling_04",
|
||||||
|
# Tests if model can use current date and run multiple tool calls
|
||||||
|
# sequentially and infer using both
|
||||||
"inference:chat_completion:multi_turn_tool_calling_05",
|
"inference:chat_completion:multi_turn_tool_calling_05",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_text_chat_completion_with_tool_calling_loop_non_streaming(client_with_models, text_model_id, test_case):
|
def test_text_chat_completion_with_multi_turn_tool_calling(client_with_models, text_model_id, test_case):
|
||||||
from rich.pretty import pprint
|
"""This test tests the model's tool calling loop in various scenarios"""
|
||||||
|
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
# keep going until either
|
# keep going until either
|
||||||
# 1. we have messages to test in multi-turn
|
# 1. we have messages to test in multi-turn
|
||||||
# 2. no messages bust last message is tool response
|
# 2. no messages bust last message is tool response
|
||||||
while len(tc["messages"]) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
|
while len(tc["messages"]) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
|
||||||
# do not take new messages if last message is tool response
|
# do not take new messages if last message is tool response
|
||||||
if (
|
if len(messages) == 0 or messages[-1]["role"] != "tool":
|
||||||
len(messages) == 0
|
|
||||||
or (isinstance(messages[-1], dict) and messages[-1]["role"] != "tool")
|
|
||||||
or (not isinstance(messages[-1], dict) and messages[-1].role != "tool")
|
|
||||||
):
|
|
||||||
new_messages = tc["messages"].pop(0)
|
new_messages = tc["messages"].pop(0)
|
||||||
messages += new_messages
|
messages += new_messages
|
||||||
|
|
||||||
pprint(messages)
|
# pprint(messages)
|
||||||
response = client_with_models.inference.chat_completion(
|
response = client_with_models.inference.chat_completion(
|
||||||
model_id=text_model_id,
|
model_id=text_model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tc["tools"],
|
tools=tc["tools"],
|
||||||
stream=False,
|
stream=False,
|
||||||
# sampling_params={
|
sampling_params={
|
||||||
# "strategy": {
|
"strategy": {
|
||||||
# "type": "top_p",
|
"type": "top_p",
|
||||||
# "top_p": 0.9,
|
"top_p": 0.9,
|
||||||
# "temperature": 0.6,
|
"temperature": 0.6,
|
||||||
# }
|
}
|
||||||
# },
|
},
|
||||||
)
|
)
|
||||||
op_msg = response.completion_message
|
op_msg = response.completion_message
|
||||||
messages.append(op_msg.model_dump())
|
messages.append(op_msg.model_dump())
|
||||||
pprint(op_msg)
|
# pprint(op_msg)
|
||||||
|
|
||||||
assert op_msg.role == "assistant"
|
assert op_msg.role == "assistant"
|
||||||
expected = tc["expected"].pop(0)
|
expected = tc["expected"].pop(0)
|
||||||
|
@ -558,5 +563,5 @@ def test_text_chat_completion_with_tool_calling_loop_non_streaming(client_with_m
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
actual_answer = op_msg.content.lower()
|
actual_answer = op_msg.content.lower()
|
||||||
pprint(actual_answer)
|
# pprint(actual_answer)
|
||||||
assert expected["answer"] in actual_answer
|
assert expected["answer"] in actual_answer
|
||||||
|
|
|
@ -94,8 +94,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"expected": {
|
"expected": {
|
||||||
"num_tool_calls": 1,
|
"location": "San Francisco, CA"
|
||||||
"expected": "San Francisco, CA"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue