feat(internal): add image_url download feature to OpenAIMixin (#3516)

# What does this PR do?

simplify Ollama inference adapter by -
 - moving image_url download code to OpenAIMixin
- being a ModelRegistryHelper instead of having one (mypy blocks
check_model_availability method assignment)

## Test Plan

 - add unit tests for new download feature
- add integration tests for openai_chat_completion w/ image_url (close
test gap)
This commit is contained in:
Matthew Farrellee 2025-09-26 17:32:16 -04:00 committed by GitHub
parent 4487b88ffe
commit b48d5cfed7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 253 additions and 84 deletions

View file

@ -6,8 +6,7 @@
import asyncio
import base64
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncGenerator
from typing import Any
from ollama import AsyncClient as AsyncOllamaClient
@ -33,10 +32,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -62,7 +57,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
@ -75,7 +69,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
content_has_media,
convert_image_content_to_url,
interleaved_content_as_str,
localize_image_content,
request_has_media,
)
@ -84,6 +77,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
OpenAIMixin,
ModelRegistryHelper,
InferenceProvider,
ModelsProtocolPrivate,
):
@ -129,6 +123,8 @@ class OllamaInferenceAdapter(
],
)
self.config = config
# Ollama does not support image urls, so we need to download the image and convert it to base64
self.download_images = True
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property
@ -173,9 +169,6 @@ class OllamaInferenceAdapter(
async def shutdown(self) -> None:
self._clients.clear()
async def unregister_model(self, model_id: str) -> None:
pass
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
@ -403,75 +396,6 @@ class OllamaInferenceAdapter(
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self._get_model(model)
# Ollama does not support image urls, so we need to download the image and convert it to base64
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
if isinstance(m.content, list):
for c in m.content:
if c.type == "image_url" and c.image_url and c.image_url.url:
localize_result = await localize_image_content(c.image_url.url)
if localize_result is None:
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
content, format = localize_result
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
return m
messages = [await _convert_message(m) for m in messages]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await OpenAIMixin.openai_chat_completion(self, **params)
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict:

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
@ -26,6 +27,7 @@ from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
logger = get_logger(name=__name__, category="providers::utils")
@ -51,6 +53,10 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
# This is useful for providers that do not return a unique id in the response.
overwrite_completion_id: bool = False
# Allow subclasses to control whether to download images and convert to base64
# for providers that require base64 encoded images instead of URLs.
download_images: bool = False
# Embedding model metadata for this provider
# Can be set by subclasses or instances to provide embedding models
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
@ -239,6 +245,24 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
"""
Direct OpenAI chat completion API call.
"""
if self.download_images:
async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
if isinstance(m.content, list):
for c in m.content:
if c.type == "image_url" and c.image_url and c.image_url.url and "http" in c.image_url.url:
localize_result = await localize_image_content(c.image_url.url)
if localize_result is None:
raise ValueError(
f"Failed to localize image content from {c.image_url.url[:42]}{'...' if len(c.image_url.url) > 42 else ''}"
)
content, format = localize_result
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
# else it's a string and we don't need to modify it
return m
messages = [await _localize_image_url(m) for m in messages]
resp = await self.client.chat.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),

View file

@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import pathlib
import pytest
@pytest.fixture
def image_path():
return pathlib.Path(__file__).parent / "dog.png"
@pytest.fixture
def base64_image_data(image_path):
return base64.b64encode(image_path.read_bytes()).decode("utf-8")
async def test_openai_chat_completion_image_url(openai_client, vision_model_id):
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
},
},
{
"type": "text",
"text": "Describe what is in this image.",
},
],
}
response = openai_client.chat.completions.create(
model=vision_model_id,
messages=[message],
stream=False,
)
message_content = response.choices[0].message.content.lower().strip()
assert len(message_content) > 0
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
async def test_openai_chat_completion_image_data(openai_client, vision_model_id, base64_image_data):
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image_data}",
},
},
{
"type": "text",
"text": "Describe what is in this image.",
},
],
}
response = openai_client.chat.completions.create(
model=vision_model_id,
messages=[message],
stream=False,
)
message_content = response.choices[0].message.content.lower().strip()
assert len(message_content) > 0
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})

File diff suppressed because one or more lines are too long

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, PropertyMock, patch
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from llama_stack.apis.inference import Model
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -43,8 +43,17 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin"""
return OpenAIMixinImpl()
"""Create a test instance of OpenAIMixin with mocked model_store"""
mixin_instance = OpenAIMixinImpl()
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-provider-resource-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
mixin_instance.model_store = mock_model_store
return mixin_instance
@pytest.fixture
@ -205,6 +214,74 @@ class TestOpenAIMixinCacheBehavior:
assert "final-mock-model-id" in mixin._model_cache
class TestOpenAIMixinImagePreprocessing:
"""Test cases for image preprocessing functionality"""
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
"""Test that image URLs are converted to base64 when download_images is True"""
mixin.download_images = True
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_called_once_with("http://example.com/image.jpg")
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[0]["type"] == "text"
assert content[1]["type"] == "image_url"
assert content[1]["image_url"]["url"] == "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh"
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
"""Test that image URLs are not modified when download_images is False"""
mixin.download_images = False # explicitly set to False
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_not_called()
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
class TestOpenAIMixinEmbeddingModelMetadata:
"""Test cases for embedding_model_metadata attribute functionality"""