feat(internal): add image_url download feature to OpenAIMixin (#3516)

# What does this PR do?

simplify Ollama inference adapter by -
 - moving image_url download code to OpenAIMixin
- being a ModelRegistryHelper instead of having one (mypy blocks
check_model_availability method assignment)

## Test Plan

 - add unit tests for new download feature
- add integration tests for openai_chat_completion w/ image_url (close
test gap)
This commit is contained in:
Matthew Farrellee 2025-09-26 17:32:16 -04:00 committed by GitHub
parent 4487b88ffe
commit b48d5cfed7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 253 additions and 84 deletions

View file

@ -6,8 +6,7 @@
import asyncio import asyncio
import base64 from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any from typing import Any
from ollama import AsyncClient as AsyncOllamaClient from ollama import AsyncClient as AsyncOllamaClient
@ -33,10 +32,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
LogProbConfig, LogProbConfig,
Message, Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
TextTruncation, TextTruncation,
@ -62,7 +57,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice, OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse, OpenAICompatCompletionResponse,
get_sampling_options, get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response, process_chat_completion_response,
process_chat_completion_stream_response, process_chat_completion_stream_response,
process_completion_response, process_completion_response,
@ -75,7 +69,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
content_has_media, content_has_media,
convert_image_content_to_url, convert_image_content_to_url,
interleaved_content_as_str, interleaved_content_as_str,
localize_image_content,
request_has_media, request_has_media,
) )
@ -84,6 +77,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter( class OllamaInferenceAdapter(
OpenAIMixin, OpenAIMixin,
ModelRegistryHelper,
InferenceProvider, InferenceProvider,
ModelsProtocolPrivate, ModelsProtocolPrivate,
): ):
@ -129,6 +123,8 @@ class OllamaInferenceAdapter(
], ],
) )
self.config = config self.config = config
# Ollama does not support image urls, so we need to download the image and convert it to base64
self.download_images = True
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {} self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property @property
@ -173,9 +169,6 @@ class OllamaInferenceAdapter(
async def shutdown(self) -> None: async def shutdown(self) -> None:
self._clients.clear() self._clients.clear()
async def unregister_model(self, model_id: str) -> None:
pass
async def _get_model(self, model_id: str) -> Model: async def _get_model(self, model_id: str) -> Model:
if not self.model_store: if not self.model_store:
raise ValueError("Model store not set") raise ValueError("Model store not set")
@ -403,75 +396,6 @@ class OllamaInferenceAdapter(
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys())) raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self._get_model(model)
# Ollama does not support image urls, so we need to download the image and convert it to base64
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
if isinstance(m.content, list):
for c in m.content:
if c.type == "image_url" and c.image_url and c.image_url.url:
localize_result = await localize_image_content(c.image_url.url)
if localize_result is None:
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
content, format = localize_result
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
return m
messages = [await _convert_message(m) for m in messages]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await OpenAIMixin.openai_chat_completion(self, **params)
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict: async def _convert_content(content) -> dict:

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
@ -26,6 +27,7 @@ from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
logger = get_logger(name=__name__, category="providers::utils") logger = get_logger(name=__name__, category="providers::utils")
@ -51,6 +53,10 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
# This is useful for providers that do not return a unique id in the response. # This is useful for providers that do not return a unique id in the response.
overwrite_completion_id: bool = False overwrite_completion_id: bool = False
# Allow subclasses to control whether to download images and convert to base64
# for providers that require base64 encoded images instead of URLs.
download_images: bool = False
# Embedding model metadata for this provider # Embedding model metadata for this provider
# Can be set by subclasses or instances to provide embedding models # Can be set by subclasses or instances to provide embedding models
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}} # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
@ -239,6 +245,24 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
""" """
Direct OpenAI chat completion API call. Direct OpenAI chat completion API call.
""" """
if self.download_images:
async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
if isinstance(m.content, list):
for c in m.content:
if c.type == "image_url" and c.image_url and c.image_url.url and "http" in c.image_url.url:
localize_result = await localize_image_content(c.image_url.url)
if localize_result is None:
raise ValueError(
f"Failed to localize image content from {c.image_url.url[:42]}{'...' if len(c.image_url.url) > 42 else ''}"
)
content, format = localize_result
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
# else it's a string and we don't need to modify it
return m
messages = [await _localize_image_url(m) for m in messages]
resp = await self.client.chat.completions.create( resp = await self.client.chat.completions.create(
**await prepare_openai_completion_params( **await prepare_openai_completion_params(
model=await self._get_provider_model_id(model), model=await self._get_provider_model_id(model),

View file

@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import pathlib
import pytest
@pytest.fixture
def image_path():
return pathlib.Path(__file__).parent / "dog.png"
@pytest.fixture
def base64_image_data(image_path):
return base64.b64encode(image_path.read_bytes()).decode("utf-8")
async def test_openai_chat_completion_image_url(openai_client, vision_model_id):
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
},
},
{
"type": "text",
"text": "Describe what is in this image.",
},
],
}
response = openai_client.chat.completions.create(
model=vision_model_id,
messages=[message],
stream=False,
)
message_content = response.choices[0].message.content.lower().strip()
assert len(message_content) > 0
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
async def test_openai_chat_completion_image_data(openai_client, vision_model_id, base64_image_data):
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image_data}",
},
},
{
"type": "text",
"text": "Describe what is in this image.",
},
],
}
response = openai_client.chat.completions.create(
model=vision_model_id,
messages=[message],
stream=False,
)
message_content = response.choices[0].message.content.lower().strip()
assert len(message_content) > 0
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})

File diff suppressed because one or more lines are too long

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from unittest.mock import MagicMock, PropertyMock, patch from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest import pytest
from llama_stack.apis.inference import Model from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -43,8 +43,17 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
@pytest.fixture @pytest.fixture
def mixin(): def mixin():
"""Create a test instance of OpenAIMixin""" """Create a test instance of OpenAIMixin with mocked model_store"""
return OpenAIMixinImpl() mixin_instance = OpenAIMixinImpl()
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-provider-resource-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
mixin_instance.model_store = mock_model_store
return mixin_instance
@pytest.fixture @pytest.fixture
@ -205,6 +214,74 @@ class TestOpenAIMixinCacheBehavior:
assert "final-mock-model-id" in mixin._model_cache assert "final-mock-model-id" in mixin._model_cache
class TestOpenAIMixinImagePreprocessing:
"""Test cases for image preprocessing functionality"""
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
"""Test that image URLs are converted to base64 when download_images is True"""
mixin.download_images = True
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_called_once_with("http://example.com/image.jpg")
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[0]["type"] == "text"
assert content[1]["type"] == "image_url"
assert content[1]["image_url"]["url"] == ""
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
"""Test that image URLs are not modified when download_images is False"""
mixin.download_images = False # explicitly set to False
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_not_called()
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
class TestOpenAIMixinEmbeddingModelMetadata: class TestOpenAIMixinEmbeddingModelMetadata:
"""Test cases for embedding_model_metadata attribute functionality""" """Test cases for embedding_model_metadata attribute functionality"""