feat(internal): add image_url download feature to OpenAIMixin

simplify Ollama inference adapter by -
 - moving image_url download code to OpenAIMixin
 - being a ModelRegistryHelper instead of having one (mypy blocks check_model_availability method assignment)

testing -
 - add unit tests for new download feature
 - add integration tests for openai_chat_completion w/ image_url (close test gap)
This commit is contained in:
Matthew Farrellee 2025-09-22 06:56:56 -04:00
parent e3f77c1004
commit 65c4ffca28
5 changed files with 257 additions and 87 deletions

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, PropertyMock, patch
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from llama_stack.apis.inference import Model
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -27,8 +27,17 @@ class OpenAIMixinImpl(OpenAIMixin):
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin"""
return OpenAIMixinImpl()
"""Create a test instance of OpenAIMixin with mocked model_store"""
mixin_instance = OpenAIMixinImpl()
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-provider-resource-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
mixin_instance.model_store = mock_model_store
return mixin_instance
@pytest.fixture
@ -181,3 +190,71 @@ class TestOpenAIMixinCacheBehavior:
assert "some-mock-model-id" in mixin._model_cache
assert "another-mock-model-id" in mixin._model_cache
assert "final-mock-model-id" in mixin._model_cache
class TestOpenAIMixinImagePreprocessing:
"""Test cases for image preprocessing functionality"""
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
"""Test that image URLs are converted to base64 when download_images is True"""
mixin.download_images = True
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_called_once_with("http://example.com/image.jpg")
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[0]["type"] == "text"
assert content[1]["type"] == "image_url"
assert content[1]["image_url"]["url"] == ""
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
"""Test that image URLs are not modified when download_images is False"""
mixin.download_images = False # explicitly set to False
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_not_called()
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"