diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 2d83bf82b..c7717479a 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 
+import base64
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
@@ -77,6 +78,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     content_has_media,
     convert_image_content_to_url,
     interleaved_content_as_str,
+    localize_image_content,
     request_has_media,
 )
 
@@ -496,6 +498,27 @@ class OllamaInferenceAdapter(
         user: str | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         model_obj = await self._get_model(model)
+
+        # Ollama does not support image urls, so we need to download the image and convert it to base64
+        async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
+            if isinstance(m.content, list):
+                for c in m.content:
+                    if c.type == "image_url" and c.image_url and c.image_url.url:
+                        # A data: URL already carries the image inline, so there is nothing to
+                        # download; localize_image_content() only handles http(s) and would
+                        # return None here, rejecting an otherwise valid request.
+                        if c.image_url.url.startswith("data:"):
+                            continue
+
+                        localize_result = await localize_image_content(c.image_url.url)
+                        if localize_result is None:
+                            raise ValueError(f"Failed to localize image content from {c.image_url.url}")
+
+                        content, format = localize_result
+                        c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
+            return m
+
+        messages = [await _convert_message(m) for m in messages]
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,
             messages=messages,
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 56e33cfdf..bb9a91b97 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -180,11 +180,10 @@ def request_has_media(request: ChatCompletionRequest | CompletionRequest):
         return content_has_media(request.content)
 
 
-async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
-    image = media.image
-    if image.url and image.url.uri.startswith("http"):
+async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
+    if uri.startswith("http"):
         async with httpx.AsyncClient() as client:
-            r = await client.get(image.url.uri)
+            r = await client.get(uri)
             content = r.content
             content_type = r.headers.get("content-type")
             if content_type:
@@ -194,11 +193,7 @@ async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
 
             return content, format
     else:
-        # data is a base64 encoded string, decode it to bytes first
-        # TODO(mf): do this more efficiently, decode less
-        data_bytes = base64.b64decode(image.data)
-        pil_image = PIL_Image.open(io.BytesIO(data_bytes))
-        return data_bytes, pil_image.format
+        return None
 
 
 async def convert_image_content_to_url(
@@ -208,7 +203,18 @@ async def convert_image_content_to_url(
     if image.url and (not download or image.url.uri.startswith("data")):
         return image.url.uri
 
-    content, format = await localize_image_content(media)
+    if image.data:
+        # data is a base64 encoded string, decode it to bytes first
+        # TODO(mf): do this more efficiently, decode less
+        content = base64.b64decode(image.data)
+        pil_image = PIL_Image.open(io.BytesIO(content))
+        format = pil_image.format
+    else:
+        localize_result = await localize_image_content(image.url.uri)
+        if localize_result is None:
+            raise ValueError(f"Failed to localize image content from {image.url.uri}")
+        content, format = localize_result
+
     if include_format:
         return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
     else:
diff --git a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
index 1acf06388..6db0dd970 100644
--- a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
+++ b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
@@ -8,6 +8,17 @@ test_response_basic:
       - case_id: "saturn"
         input: "Which planet has rings around it with a name starting with letter S?"
         output: "saturn"
+      - case_id: "image_input"
+        input:
+          - role: user
+            content:
+              - type: input_text
+                text: "what teams are playing in this image?"
+          - role: user
+            content:
+              - type: input_image
+                image_url: "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg"
+        output: "brooklyn nets"
 
 test_response_multi_turn:
   test_name: test_response_multi_turn