diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py index 6dff1be24..81cb0adf3 100644 --- a/tests/client-sdk/inference/test_inference.py +++ b/tests/client-sdk/inference/test_inference.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import base64 -import os +import pathlib import pytest from pydantic import BaseModel @@ -57,13 +57,20 @@ def get_weather_tool_definition(): @pytest.fixture -def base64_image_url(): - image_path = os.path.join(os.path.dirname(__file__), "dog.png") - with open(image_path, "rb") as image_file: - # Convert the image to base64 - base64_string = base64.b64encode(image_file.read()).decode("utf-8") - base64_url = f"data:image/png;base64,{base64_string}" - return base64_url +def image_path(): + return pathlib.Path(__file__).parent / "dog.png" + + +@pytest.fixture +def base64_image_data(image_path): + # Convert the image to base64 + return base64.b64encode(image_path.read_bytes()).decode("utf-8") + + +@pytest.fixture +def base64_image_url(base64_image_data, image_path): + # suffix includes the ., so we remove it + return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}" def test_text_completion_non_streaming(llama_stack_client, text_model_id): @@ -371,6 +378,33 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"}) +def test_image_chat_completion_base64_data( + llama_stack_client, vision_model_id, base64_image_data +): + message = { + "role": "user", + "content": [ + { + "type": "image", + "image": { + "data": base64_image_data, + }, + }, + { + "type": "text", + "text": "Describe what is in this image.", + }, + ], + } + response = llama_stack_client.inference.chat_completion( + model_id=vision_model_id, + messages=[message], + stream=False, + ) + message_content = response.completion_message.content.lower().strip() + assert len(message_content) > 0 + + def test_image_chat_completion_base64_url( llama_stack_client, vision_model_id, base64_image_url ):