From b333a3c03ae29c0e448ac58053246286e4e8c5b8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 30 Jun 2025 20:36:11 +0530 Subject: [PATCH] fix(ollama): Download remote image URLs for Ollama (#2551) ## What does this PR do? Ollama does not support remote images. Only local file paths OR base64 inputs are supported. This PR ensures that the Stack downloads remote images and passes the base64 down to the inference engine. ## Test Plan Added test cases for Responses and ran them for both `fireworks` and `ollama` providers. --- .../remote/inference/ollama/ollama.py | 17 ++++++++++++ .../utils/inference/prompt_adapter.py | 26 ++++++++++++------- .../fixtures/test_cases/responses.yaml | 11 ++++++++ 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 2d83bf82b..c7717479a 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
+import base64 import uuid from collections.abc import AsyncGenerator, AsyncIterator from typing import Any @@ -77,6 +78,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( content_has_media, convert_image_content_to_url, interleaved_content_as_str, + localize_image_content, request_has_media, ) @@ -496,6 +498,21 @@ class OllamaInferenceAdapter( user: str | None = None, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: model_obj = await self._get_model(model) + + # Ollama does not support image urls, so we need to download the image and convert it to base64 + async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam: + if isinstance(m.content, list): + for c in m.content: + if c.type == "image_url" and c.image_url and c.image_url.url: + localize_result = await localize_image_content(c.image_url.url) + if localize_result is None: + raise ValueError(f"Failed to localize image content from {c.image_url.url}") + + content, format = localize_result + c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}" + return m + + messages = [await _convert_message(m) for m in messages] params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, messages=messages, diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 56e33cfdf..bb9a91b97 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -180,11 +180,10 @@ def request_has_media(request: ChatCompletionRequest | CompletionRequest): return content_has_media(request.content) -async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]: - image = media.image - if image.url and image.url.uri.startswith("http"): +async def localize_image_content(uri: str) -> tuple[bytes, str] | None: + if uri.startswith("http"): async with httpx.AsyncClient() as 
client: - r = await client.get(image.url.uri) + r = await client.get(uri) content = r.content content_type = r.headers.get("content-type") if content_type: @@ -194,11 +193,7 @@ async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]: return content, format else: - # data is a base64 encoded string, decode it to bytes first - # TODO(mf): do this more efficiently, decode less - data_bytes = base64.b64decode(image.data) - pil_image = PIL_Image.open(io.BytesIO(data_bytes)) - return data_bytes, pil_image.format + return None async def convert_image_content_to_url( @@ -208,7 +203,18 @@ async def convert_image_content_to_url( if image.url and (not download or image.url.uri.startswith("data")): return image.url.uri - content, format = await localize_image_content(media) + if image.data: + # data is a base64 encoded string, decode it to bytes first + # TODO(mf): do this more efficiently, decode less + content = base64.b64decode(image.data) + pil_image = PIL_Image.open(io.BytesIO(content)) + format = pil_image.format + else: + localize_result = await localize_image_content(image.url.uri) + if localize_result is None: + raise ValueError(f"Failed to localize image content from {image.url.uri}") + content, format = localize_result + if include_format: return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8") else: diff --git a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml index 1acf06388..6db0dd970 100644 --- a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml +++ b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml @@ -8,6 +8,17 @@ test_response_basic: - case_id: "saturn" input: "Which planet has rings around it with a name starting with letter S?" output: "saturn" + - case_id: "image_input" + input: + - role: user + content: + - type: input_text + text: "what teams are playing in this image?" 
+ - role: user + content: + - type: input_image + image_url: "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg" + output: "brooklyn nets" test_response_multi_turn: test_name: test_response_multi_turn