Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-03 19:57:35 +00:00)
feat(internal): add image_url download feature to OpenAIMixin

Simplify the Ollama inference adapter by:
- moving the image_url download code into OpenAIMixin
- being a ModelRegistryHelper instead of having one (mypy blocks assigning the check_model_availability method)

Testing:
- add unit tests for the new download feature
- add integration tests for openai_chat_completion with image_url (closes a test gap)
Parent: e3f77c1004
Commit: 65c4ffca28
5 changed files with 257 additions and 87 deletions
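For context on the diff below: OpenAIMixin gains a download_images flag (default False). When a subclass sets it to True, openai_chat_completion() downloads any http(s) image_url content via localize_image_content() and rewrites it as a base64 data URL before calling the backend, which is what Ollama needs since it cannot fetch remote image URLs itself. A minimal sketch of how a provider adapter might opt in; the AcmeInferenceAdapter name is hypothetical, while OpenAIMixin and download_images come from this commit:

    from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


    class AcmeInferenceAdapter(OpenAIMixin):
        # Hypothetical adapter for a backend that, like Ollama, cannot fetch remote image URLs.
        # A real adapter also implements OpenAIMixin's abstract methods, omitted here.

        # Ask OpenAIMixin.openai_chat_completion() to download http(s) image_url content and
        # rewrite it as data:image/<format>;base64,... before the request is sent.
        download_images = True

With the flag set, callers keep passing plain image URLs in messages and the mixin localizes them transparently; leaving it at the default False preserves the existing pass-through behavior.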
Ollama inference adapter (OllamaInferenceAdapter):

@@ -6,8 +6,7 @@
 import asyncio
-import base64
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator
 from typing import Any
 
 from ollama import AsyncClient as AsyncOllamaClient

@@ -33,10 +32,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     SamplingParams,
     TextTruncation,

@@ -60,7 +55,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,

@@ -73,7 +67,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     content_has_media,
     convert_image_content_to_url,
     interleaved_content_as_str,
-    localize_image_content,
     request_has_media,
 )

@@ -84,6 +77,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
 
 class OllamaInferenceAdapter(
     OpenAIMixin,
+    ModelRegistryHelper,
     InferenceProvider,
     ModelsProtocolPrivate,
 ):

@@ -91,8 +85,10 @@ class OllamaInferenceAdapter(
     __provider_id__: str
 
     def __init__(self, config: OllamaImplConfig) -> None:
-        self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
+        # Ollama does not support image urls, so we need to download the image and convert it to base64
+        self.download_images = True
         self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
 
     @property

@@ -171,6 +167,7 @@ class OllamaInferenceAdapter(
                     model_type=ModelType.llm,
                 )
             )
+        self._model_cache = {m.identifier: m for m in models}  # for fast check_model_availability
         return models
 
     async def health(self) -> HealthResponse:

@@ -190,9 +187,6 @@ class OllamaInferenceAdapter(
     async def shutdown(self) -> None:
         self._clients.clear()
 
-    async def unregister_model(self, model_id: str) -> None:
-        pass
-
     async def _get_model(self, model_id: str) -> Model:
         if not self.model_store:
             raise ValueError("Model store not set")

@@ -301,7 +295,7 @@ class OllamaInferenceAdapter(
 
         input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
-        llama_model = self.register_helper.get_llama_model(request.model)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]

@@ -410,7 +404,7 @@ class OllamaInferenceAdapter(
 
     async def register_model(self, model: Model) -> Model:
         try:
-            model = await self.register_helper.register_model(model)
+            model = await super().register_model(model)
         except ValueError:
             pass  # Ignore statically unknown model, will check live listing
 

@@ -441,75 +435,6 @@ class OllamaInferenceAdapter(
 
         return model
 
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self._get_model(model)
-
-        # Ollama does not support image urls, so we need to download the image and convert it to base64
-        async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
-            if isinstance(m.content, list):
-                for c in m.content:
-                    if c.type == "image_url" and c.image_url and c.image_url.url:
-                        localize_result = await localize_image_content(c.image_url.url)
-                        if localize_result is None:
-                            raise ValueError(f"Failed to localize image content from {c.image_url.url}")
-
-                        content, format = localize_result
-                        c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
-            return m
-
-        messages = [await _convert_message(m) for m in messages]
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        return await OpenAIMixin.openai_chat_completion(self, **params)
-
-
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
     async def _convert_content(content) -> dict:
llama_stack/providers/utils/inference/openai_mixin.py (OpenAIMixin):

@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
 import uuid
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator

@@ -25,6 +26,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
 
 logger = get_logger(name=__name__, category="providers::utils")
 

@@ -50,6 +52,10 @@ class OpenAIMixin(ABC):
     # This is useful for providers that do not return a unique id in the response.
     overwrite_completion_id: bool = False
 
+    # Allow subclasses to control whether to download images and convert to base64
+    # for providers that require base64 encoded images instead of URLs.
+    download_images: bool = False
+
     # Cache of available models keyed by model ID
     # This is set in list_models() and used in check_model_availability()
     _model_cache: dict[str, Model] = {}

@@ -230,6 +236,24 @@ class OpenAIMixin(ABC):
         """
         Direct OpenAI chat completion API call.
         """
+        if self.download_images:
+
+            async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
+                if isinstance(m.content, list):
+                    for c in m.content:
+                        if c.type == "image_url" and c.image_url and c.image_url.url and "http" in c.image_url.url:
+                            localize_result = await localize_image_content(c.image_url.url)
+                            if localize_result is None:
+                                raise ValueError(
+                                    f"Failed to localize image content from {c.image_url.url[:42]}{'...' if len(c.image_url.url) > 42 else ''}"
+                                )
+                            content, format = localize_result
+                            c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
+                # else it's a string and we don't need to modify it
+                return m
+
+            messages = [await _localize_image_url(m) for m in messages]
+
         resp = await self.client.chat.completions.create(
             **await prepare_openai_completion_params(
                 model=await self._get_provider_model_id(model),
tests/integration/inference/test_openai_vision_inference.py (new file, 77 lines):
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import base64
import pathlib

import pytest


@pytest.fixture
def image_path():
    return pathlib.Path(__file__).parent / "dog.png"


@pytest.fixture
def base64_image_data(image_path):
    return base64.b64encode(image_path.read_bytes()).decode("utf-8")


async def test_openai_chat_completion_image_url(openai_client, vision_model_id):
    message = {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
                },
            },
            {
                "type": "text",
                "text": "Describe what is in this image.",
            },
        ],
    }

    response = openai_client.chat.completions.create(
        model=vision_model_id,
        messages=[message],
        stream=False,
    )

    message_content = response.choices[0].message.content.lower().strip()
    assert len(message_content) > 0
    assert any(expected in message_content for expected in {"dog", "puppy", "pup"})


async def test_openai_chat_completion_image_data(openai_client, vision_model_id, base64_image_data):
    message = {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{base64_image_data}",
                },
            },
            {
                "type": "text",
                "text": "Describe what is in this image.",
            },
        ],
    }

    response = openai_client.chat.completions.create(
        model=vision_model_id,
        messages=[message],
        stream=False,
    )

    message_content = response.choices[0].message.content.lower().strip()
    assert len(message_content) > 0
    assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
tests/integration/recordings/responses/d927b47032de.json (new file, 67 lines): diff suppressed because one or more lines are too long.
|
@ -4,11 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from unittest.mock import MagicMock, PropertyMock, patch
|
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.inference import Model
|
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
|
@ -27,8 +27,17 @@ class OpenAIMixinImpl(OpenAIMixin):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mixin():
|
def mixin():
|
||||||
"""Create a test instance of OpenAIMixin"""
|
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||||
return OpenAIMixinImpl()
|
mixin_instance = OpenAIMixinImpl()
|
||||||
|
|
||||||
|
# just enough to satisfy _get_provider_model_id calls
|
||||||
|
mock_model_store = MagicMock()
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_model.provider_resource_id = "test-provider-resource-id"
|
||||||
|
mock_model_store.get_model = AsyncMock(return_value=mock_model)
|
||||||
|
mixin_instance.model_store = mock_model_store
|
||||||
|
|
||||||
|
return mixin_instance
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -181,3 +190,71 @@ class TestOpenAIMixinCacheBehavior:
|
||||||
assert "some-mock-model-id" in mixin._model_cache
|
assert "some-mock-model-id" in mixin._model_cache
|
||||||
assert "another-mock-model-id" in mixin._model_cache
|
assert "another-mock-model-id" in mixin._model_cache
|
||||||
assert "final-mock-model-id" in mixin._model_cache
|
assert "final-mock-model-id" in mixin._model_cache
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIMixinImagePreprocessing:
|
||||||
|
"""Test cases for image preprocessing functionality"""
|
||||||
|
|
||||||
|
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
|
||||||
|
"""Test that image URLs are converted to base64 when download_images is True"""
|
||||||
|
mixin.download_images = True
|
||||||
|
|
||||||
|
message = OpenAIUserMessageParam(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||||
|
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||||
|
mock_localize.return_value = (b"fake_image_data", "jpeg")
|
||||||
|
|
||||||
|
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||||
|
|
||||||
|
mock_localize.assert_called_once_with("http://example.com/image.jpg")
|
||||||
|
|
||||||
|
mock_client.chat.completions.create.assert_called_once()
|
||||||
|
call_args = mock_client.chat.completions.create.call_args
|
||||||
|
processed_messages = call_args[1]["messages"]
|
||||||
|
assert len(processed_messages) == 1
|
||||||
|
content = processed_messages[0]["content"]
|
||||||
|
assert len(content) == 2
|
||||||
|
assert content[0]["type"] == "text"
|
||||||
|
assert content[1]["type"] == "image_url"
|
||||||
|
assert content[1]["image_url"]["url"] == "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh"
|
||||||
|
|
||||||
|
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
|
||||||
|
"""Test that image URLs are not modified when download_images is False"""
|
||||||
|
mixin.download_images = False # explicitly set to False
|
||||||
|
|
||||||
|
message = OpenAIUserMessageParam(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||||
|
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||||
|
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||||
|
|
||||||
|
mock_localize.assert_not_called()
|
||||||
|
|
||||||
|
mock_client.chat.completions.create.assert_called_once()
|
||||||
|
call_args = mock_client.chat.completions.create.call_args
|
||||||
|
processed_messages = call_args[1]["messages"]
|
||||||
|
assert len(processed_messages) == 1
|
||||||
|
content = processed_messages[0]["content"]
|
||||||
|
assert len(content) == 2
|
||||||
|
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
|
||||||
|
|