mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
feat(internal): add image_url download feature to OpenAIMixin
simplify Ollama inference adapter by - - moving image_url download code to OpenAIMixin - being a ModelRegistryHelper instead of having one (mypy blocks check_model_availability method assignment) testing - - add unit tests for new download feature - add integration tests for openai_chat_completion w/ image_url (close test gap)
This commit is contained in:
parent
e3f77c1004
commit
65c4ffca28
5 changed files with 257 additions and 87 deletions
|
@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
@ -27,8 +27,17 @@ class OpenAIMixinImpl(OpenAIMixin):
|
|||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
"""Create a test instance of OpenAIMixin"""
|
||||
return OpenAIMixinImpl()
|
||||
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||
mixin_instance = OpenAIMixinImpl()
|
||||
|
||||
# just enough to satisfy _get_provider_model_id calls
|
||||
mock_model_store = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_model.provider_resource_id = "test-provider-resource-id"
|
||||
mock_model_store.get_model = AsyncMock(return_value=mock_model)
|
||||
mixin_instance.model_store = mock_model_store
|
||||
|
||||
return mixin_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -181,3 +190,71 @@ class TestOpenAIMixinCacheBehavior:
|
|||
assert "some-mock-model-id" in mixin._model_cache
|
||||
assert "another-mock-model-id" in mixin._model_cache
|
||||
assert "final-mock-model-id" in mixin._model_cache
|
||||
|
||||
|
||||
class TestOpenAIMixinImagePreprocessing:
|
||||
"""Test cases for image preprocessing functionality"""
|
||||
|
||||
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
|
||||
"""Test that image URLs are converted to base64 when download_images is True"""
|
||||
mixin.download_images = True
|
||||
|
||||
message = OpenAIUserMessageParam(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
|
||||
],
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
mock_localize.return_value = (b"fake_image_data", "jpeg")
|
||||
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
|
||||
mock_localize.assert_called_once_with("http://example.com/image.jpg")
|
||||
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
processed_messages = call_args[1]["messages"]
|
||||
assert len(processed_messages) == 1
|
||||
content = processed_messages[0]["content"]
|
||||
assert len(content) == 2
|
||||
assert content[0]["type"] == "text"
|
||||
assert content[1]["type"] == "image_url"
|
||||
assert content[1]["image_url"]["url"] == "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh"
|
||||
|
||||
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
|
||||
"""Test that image URLs are not modified when download_images is False"""
|
||||
mixin.download_images = False # explicitly set to False
|
||||
|
||||
message = OpenAIUserMessageParam(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
|
||||
],
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
|
||||
mock_localize.assert_not_called()
|
||||
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
processed_messages = call_args[1]["messages"]
|
||||
assert len(processed_messages) == 1
|
||||
content = processed_messages[0]["content"]
|
||||
assert len(content) == 2
|
||||
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue