feat: introduce llama4 support (#1877)

As the title says. Details are in the README and elsewhere.
Ashwin Bharambe 2025-04-05 11:53:35 -07:00 committed by GitHub
parent 23a99a4b22
commit b8f1561956
61 changed files with 205222 additions and 6439 deletions
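
For context, here is a minimal usage sketch of the inference API exercised by the tests in this diff. It assumes a running Llama Stack server and a served Llama 4 vision model; the base URL and model id below are illustrative assumptions, not something this commit pins down.

import base64

from llama_stack_client import LlamaStackClient

# Assumption: a Llama Stack server is listening on the default local port.
client = LlamaStackClient(base_url="http://localhost:8321")

# Assumption: placeholder id; substitute whichever Llama 4 model your server serves.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

# Encode an image the same way the tests below do.
with open("vision_test_1.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

response = client.inference.chat_completion(
    model_id=model_id,
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image", "image": {"data": image_b64}},
                {"type": "text", "text": "What is in this image?"},
            ],
        }
    ],
    stream=False,
)
print(response.completion_message.content)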


@@ -4,11 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import pathlib
from pathlib import Path

import pytest

THIS_DIR = Path(__file__).parent


@pytest.fixture
def image_path():
@@ -27,7 +31,6 @@ def base64_image_url(base64_image_data, image_path):
return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
message = {
"role": "user",
@@ -56,7 +59,99 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
    assert any(expected in message_content for expected in {"dog", "puppy", "pup"})

@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")

@pytest.fixture
def multi_image_data():
    files = [
        THIS_DIR / "vision_test_1.jpg",
        THIS_DIR / "vision_test_2.jpg",
        THIS_DIR / "vision_test_3.jpg",
    ]
    encoded_files = []
    for file in files:
        with open(file, "rb") as image_file:
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
            encoded_files.append(base64_data)
    return encoded_files


@pytest.mark.parametrize("stream", [True, False])
def test_image_chat_completion_multiple_images(client_with_models, vision_model_id, multi_image_data, stream):
    if "llama-4" not in vision_model_id.lower() and "gpt-4o" not in vision_model_id.lower():
        pytest.skip("Skip for non-llama4, gpt4o models")

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": {
                        "data": multi_image_data[0],
                    },
                },
                {
                    "type": "image",
                    "image": {
                        "data": multi_image_data[1],
                    },
                },
                {
                    "type": "text",
                    "text": "What are the differences between these images? Where would you assume they would be located?",
                },
            ],
        },
    ]
    response = client_with_models.inference.chat_completion(
        model_id=vision_model_id,
        messages=messages,
        stream=stream,
    )
    if stream:
        message_content = ""
        for chunk in response:
            message_content += chunk.event.delta.text
    else:
        message_content = response.completion_message.content
    assert len(message_content) > 0
    assert any(expected in message_content.lower().strip() for expected in {"bedroom"}), message_content

    messages.append(
        {
            "role": "assistant",
            "content": [{"type": "text", "text": message_content}],
            "stop_reason": "end_of_turn",
        }
    )
    messages.append(
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": {
                        "data": multi_image_data[2],
                    },
                },
                {"type": "text", "text": "How about this one?"},
            ],
        },
    )
    response = client_with_models.inference.chat_completion(
        model_id=vision_model_id,
        messages=messages,
        stream=stream,
    )
    if stream:
        message_content = ""
        for chunk in response:
            message_content += chunk.event.delta.text
    else:
        message_content = response.completion_message.content
    assert len(message_content) > 0
    assert any(expected in message_content.lower().strip() for expected in {"sword", "shield"}), message_content


def test_image_chat_completion_streaming(client_with_models, vision_model_id):
    message = {
        "role": "user",