mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
test: add multi_image test (#1972)
# What does this PR do?

Adds a multi-turn, multi-image chat-completion verification test.

## Test Plan

pytest tests/verifications/openai_api/test_chat_completion.py --provider openai -k 'test_chat_multiple_images'
This commit is contained in:
parent
2976b5d992
commit
0ed41aafbf
16 changed files with 2416 additions and 1585 deletions
|
@ -4,9 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
@ -19,6 +21,8 @@ from tests.verifications.openai_api.fixtures.load import load_test_cases
|
|||
|
||||
# Shared chat-completion test cases, loaded once from the fixtures file.
chat_completion_test_cases = load_test_cases("chat_completion")

# Directory containing this test module; used to locate on-disk image fixtures.
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def case_id_generator(case):
|
||||
"""Generate a test ID from the case's 'case_id' field, or use a default."""
|
||||
|
@ -71,6 +75,21 @@ def get_base_test_name(request):
|
|||
return request.node.originalname
|
||||
|
||||
|
||||
@pytest.fixture
def multi_image_data():
    """Return the three vision-test fixture images as base64 data URLs.

    Each entry is a ``data:image/jpeg;base64,...`` string suitable for the
    ``image_url`` content type of the OpenAI chat-completions API. The images
    are read from ``fixtures/images`` relative to this test module.
    """
    files = [
        THIS_DIR / "fixtures/images/vision_test_1.jpg",
        THIS_DIR / "fixtures/images/vision_test_2.jpg",
        THIS_DIR / "fixtures/images/vision_test_3.jpg",
    ]
    # Path.read_bytes() replaces the manual open/read/close dance and the
    # append loop becomes a comprehension; output is unchanged.
    return [
        f"data:image/jpeg;base64,{base64.b64encode(path.read_bytes()).decode('utf-8')}"
        for path in files
    ]
|
||||
|
||||
|
||||
# --- Test Functions ---
|
||||
|
||||
|
||||
|
@ -533,6 +552,86 @@ def test_chat_streaming_multi_turn_tool_calling(request, openai_client, model, p
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [False, True], ids=["stream=False", "stream=True"])
def test_chat_multi_turn_multiple_images(
    request, openai_client, model, provider, verification_config, multi_image_data, stream
):
    """Verify a multi-turn conversation referencing multiple images.

    Turn 1 sends two images and asks which furniture appears only in the
    first (answer must mention "chair" or "table"). Turn 2 appends the
    assistant's reply plus a third image and asks what it shares with the
    first (answer must mention "bed"). Parametrized over streaming and
    non-streaming response modes.
    """
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    def _response_text(response):
        # Collapse a streaming response into a single string, or read the
        # completed message directly. Shared by both turns — the original
        # duplicated this stream-handling logic verbatim.
        if stream:
            return "".join(chunk.choices[0].delta.content or "" for chunk in response)
        return response.choices[0].message.content

    messages_turn1 = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[0],
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[1],
                    },
                },
                {
                    "type": "text",
                    "text": "What furniture is in the first image that is not in the second image?",
                },
            ],
        },
    ]

    # First API call
    response1 = openai_client.chat.completions.create(
        model=model,
        messages=messages_turn1,
        stream=stream,
    )
    message_content1 = _response_text(response1)
    assert len(message_content1) > 0
    assert any(expected in message_content1.lower().strip() for expected in {"chair", "table"}), message_content1

    # Prepare messages for the second turn: prior user turn + assistant reply
    # + a new user turn with the third image. List concatenation leaves
    # messages_turn1 untouched.
    messages_turn2 = messages_turn1 + [
        {"role": "assistant", "content": message_content1},
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[2],
                    },
                },
                {"type": "text", "text": "What is in this image that is also in the first image?"},
            ],
        },
    ]

    # Second API call
    response2 = openai_client.chat.completions.create(
        model=model,
        messages=messages_turn2,
        stream=stream,
    )
    message_content2 = _response_text(response2)
    assert len(message_content2) > 0
    assert any(expected in message_content2.lower().strip() for expected in {"bed"}), message_content2
|
||||
|
||||
|
||||
# --- Helper functions (structured output validation) ---
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue