fix(ollama): Download remote image URLs for Ollama

This commit is contained in:
Ashwin Bharambe 2025-06-30 19:31:23 +05:30
parent 6fa5271807
commit a17894a7b0
3 changed files with 44 additions and 10 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64
import uuid import uuid
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any from typing import Any
@ -77,6 +78,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
content_has_media, content_has_media,
convert_image_content_to_url, convert_image_content_to_url,
interleaved_content_as_str, interleaved_content_as_str,
localize_image_content,
request_has_media, request_has_media,
) )
@ -496,6 +498,21 @@ class OllamaInferenceAdapter(
user: str | None = None, user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self._get_model(model) model_obj = await self._get_model(model)
# Ollama does not support image urls, so we need to download the image and convert it to base64
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
if isinstance(m.content, list):
for c in m.content:
if c.type == "image_url" and c.image_url.url:
localize_result = await localize_image_content(c.image_url.url)
if localize_result is None:
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
content, format = localize_result
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
return m
messages = [await _convert_message(m) for m in messages]
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id, model=model_obj.provider_resource_id,
messages=messages, messages=messages,

View file

@ -180,11 +180,10 @@ def request_has_media(request: ChatCompletionRequest | CompletionRequest):
return content_has_media(request.content) return content_has_media(request.content)
async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]: async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
image = media.image if uri.startswith("http"):
if image.url and image.url.uri.startswith("http"):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(image.url.uri) r = await client.get(uri)
content = r.content content = r.content
content_type = r.headers.get("content-type") content_type = r.headers.get("content-type")
if content_type: if content_type:
@ -194,11 +193,7 @@ async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
return content, format return content, format
else: else:
# data is a base64 encoded string, decode it to bytes first return None
# TODO(mf): do this more efficiently, decode less
data_bytes = base64.b64decode(image.data)
pil_image = PIL_Image.open(io.BytesIO(data_bytes))
return data_bytes, pil_image.format
async def convert_image_content_to_url( async def convert_image_content_to_url(
@ -208,7 +203,18 @@ async def convert_image_content_to_url(
if image.url and (not download or image.url.uri.startswith("data")): if image.url and (not download or image.url.uri.startswith("data")):
return image.url.uri return image.url.uri
content, format = await localize_image_content(media) if image.data:
# data is a base64 encoded string, decode it to bytes first
# TODO(mf): do this more efficiently, decode less
content = base64.b64decode(image.data)
pil_image = PIL_Image.open(io.BytesIO(content))
format = pil_image.format
else:
localize_result = await localize_image_content(image.url.uri)
if localize_result is None:
raise ValueError(f"Failed to localize image content from {image.url.uri}")
content, format = localize_result
if include_format: if include_format:
return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8") return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
else: else:

View file

@ -8,6 +8,17 @@ test_response_basic:
- case_id: "saturn" - case_id: "saturn"
input: "Which planet has rings around it with a name starting with letter S?" input: "Which planet has rings around it with a name starting with letter S?"
output: "saturn" output: "saturn"
- case_id: "image_input"
input:
- role: user
content:
- type: input_text
text: "what teams are playing in this image?"
- role: user
content:
- type: input_image
image_url: "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg"
output: "brooklyn nets"
test_response_multi_turn: test_response_multi_turn:
test_name: test_response_multi_turn test_name: test_response_multi_turn