Merge branch 'main' into feat/litellm_sambanova_usage
Commit: ec73b3d066
95 changed files with 206742 additions and 6573 deletions
@@ -506,3 +506,80 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
    else:
        for tc in response.completion_message.tool_calls:
            assert tc.tool_name == "get_object_namespace_list"


@pytest.mark.parametrize(
    "test_case",
    [
        # Tests if the model can handle simple messages like "Hi" or
        # a message unrelated to one of the tool calls
        "inference:chat_completion:multi_turn_tool_calling_01",
        # Tests if the model can do a full tool call with responses correctly
        "inference:chat_completion:multi_turn_tool_calling_02",
        # Tests if the model can generate multiple params and
        # read outputs correctly
        "inference:chat_completion:multi_turn_tool_calling_03",
        # Tests if the model can do different tool calls in a sequence
        # and use the information between them appropriately
        "inference:chat_completion:multi_turn_tool_calling_04",
        # Tests if the model can use the current date and run multiple tool calls
        # sequentially and infer using both
        "inference:chat_completion:multi_turn_tool_calling_05",
    ],
)
def test_text_chat_completion_with_multi_turn_tool_calling(client_with_models, text_model_id, test_case):
    """Test the model's tool-calling loop across several multi-turn scenarios."""
    if "llama-4" not in text_model_id.lower():
        pytest.xfail("Not tested for non-llama4 models yet")

    tc = TestCase(test_case)
    messages = []

    # keep going as long as either
    # 1. there are still test messages left to feed in for the multi-turn exchange, or
    # 2. the messages are exhausted but the last message is a tool response that
    #    still needs to be sent back to the model
    while len(tc["messages"]) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
        # do not take new messages if the last message is a tool response
        if len(messages) == 0 or messages[-1]["role"] != "tool":
            new_messages = tc["messages"].pop(0)
            messages += new_messages

        # pprint(messages)
        response = client_with_models.inference.chat_completion(
            model_id=text_model_id,
            messages=messages,
            tools=tc["tools"],
            stream=False,
            sampling_params={
                "strategy": {
                    "type": "top_p",
                    "top_p": 0.9,
                    "temperature": 0.6,
                }
            },
        )
        op_msg = response.completion_message
        messages.append(op_msg.model_dump())
        # pprint(op_msg)

        assert op_msg.role == "assistant"
        expected = tc["expected"].pop(0)
        assert len(op_msg.tool_calls) == expected["num_tool_calls"]

        if expected["num_tool_calls"] > 0:
            assert op_msg.tool_calls[0].tool_name == expected["tool_name"]
            assert op_msg.tool_calls[0].arguments == expected["tool_arguments"]

            tool_response = tc["tool_responses"].pop(0)
            messages.append(
                # Tool Response Message
                {
                    "role": "tool",
                    "call_id": op_msg.tool_calls[0].call_id,
                    "content": tool_response["response"],
                }
            )
        else:
            actual_answer = op_msg.content.lower()
            # pprint(actual_answer)
            assert expected["answer"] in actual_answer
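For context on the loop above: each iteration is driven by the TestCase entry looked up via the test_case id. The fixture contents themselves are not part of this diff; the sketch below only illustrates the shape the loop expects, using the field names it actually reads (messages, tools, expected, tool_responses). All concrete values, including the get_boiling_point tool, are hypothetical.

# Hypothetical shape of one multi-turn test case, as consumed by TestCase(test_case).
# Field names mirror the accesses in the test; the values are invented for illustration.
_example_case = {
    "messages": [
        # each inner list is the batch of new messages fed to the model on one turn
        [{"role": "user", "content": "What is the boiling point of water?"}],
    ],
    "tools": [
        {
            "tool_name": "get_boiling_point",
            "description": "Returns the boiling point of a liquid",
            "parameters": {
                "liquid_name": {"param_type": "string", "description": "Name of the liquid", "required": True},
            },
        },
    ],
    "expected": [
        # one entry per assistant turn checked inside the while-loop
        {"num_tool_calls": 1, "tool_name": "get_boiling_point", "tool_arguments": {"liquid_name": "water"}},
        {"num_tool_calls": 0, "answer": "100"},
    ],
    "tool_responses": [
        {"response": "100 degrees Celsius"},
    ],
}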
@@ -4,11 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import base64
import pathlib
from pathlib import Path

import pytest

THIS_DIR = Path(__file__).parent


@pytest.fixture
def image_path():
@@ -27,7 +31,6 @@ def base64_image_url(base64_image_data, image_path):
    return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"


@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
    message = {
        "role": "user",
@@ -56,7 +59,99 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
    assert any(expected in message_content for expected in {"dog", "puppy", "pup"})


@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
@pytest.fixture
def multi_image_data():
    files = [
        THIS_DIR / "vision_test_1.jpg",
        THIS_DIR / "vision_test_2.jpg",
        THIS_DIR / "vision_test_3.jpg",
    ]
    encoded_files = []
    for file in files:
        with open(file, "rb") as image_file:
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
            encoded_files.append(base64_data)
    return encoded_files


@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_image_chat_completion_multiple_images(client_with_models, vision_model_id, multi_image_data, stream):
|
||||
if "llama-4" not in vision_model_id.lower() and "gpt-4o" not in vision_model_id.lower():
|
||||
pytest.skip("Skip for non-llama4, gpt4o models")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": {
|
||||
"data": multi_image_data[0],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image": {
|
||||
"data": multi_image_data[1],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What are the differences between these images? Where would you assume they would be located?",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
    response = client_with_models.inference.chat_completion(
        model_id=vision_model_id,
        messages=messages,
        stream=stream,
    )
    if stream:
        message_content = ""
        for chunk in response:
            message_content += chunk.event.delta.text
    else:
        message_content = response.completion_message.content
    assert len(message_content) > 0
    assert any(expected in message_content.lower().strip() for expected in {"bedroom"}), message_content

    messages.append(
        {
            "role": "assistant",
            "content": [{"type": "text", "text": message_content}],
            "stop_reason": "end_of_turn",
        }
    )
    messages.append(
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": {
                        "data": multi_image_data[2],
                    },
                },
                {"type": "text", "text": "How about this one?"},
            ],
        },
    )
    response = client_with_models.inference.chat_completion(
        model_id=vision_model_id,
        messages=messages,
        stream=stream,
    )
    if stream:
        message_content = ""
        for chunk in response:
            message_content += chunk.event.delta.text
    else:
        message_content = response.completion_message.content
    assert len(message_content) > 0
    assert any(expected in message_content.lower().strip() for expected in {"sword", "shield"}), message_content


def test_image_chat_completion_streaming(client_with_models, vision_model_id):
    message = {
        "role": "user",
BIN  tests/integration/inference/vision_test_1.jpg  (new binary file, 108 KiB; not shown)
BIN  tests/integration/inference/vision_test_2.jpg  (new binary file, 148 KiB; not shown)
BIN  tests/integration/inference/vision_test_3.jpg  (new binary file, 139 KiB; not shown)
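To run only the new cases locally, standard pytest selection is enough; the client_with_models, text_model_id, and vision_model_id fixtures are assumed to come from the integration suite's existing conftest, so no server or model configuration is shown here.

# Illustrative invocation via pytest's Python entry point; equivalent to running
# pytest tests/integration/inference -k "multi_turn_tool_calling or multiple_images" -v
# from the repository root.
import pytest

pytest.main(["tests/integration/inference", "-k", "multi_turn_tool_calling or multiple_images", "-v"])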