mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
fix(ollama): Download remote image URLs for Ollama
This commit is contained in:
parent
6fa5271807
commit
a17894a7b0
3 changed files with 44 additions and 10 deletions
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import base64
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
@ -77,6 +78,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
content_has_media,
|
||||
convert_image_content_to_url,
|
||||
interleaved_content_as_str,
|
||||
localize_image_content,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -496,6 +498,21 @@ class OllamaInferenceAdapter(
|
|||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self._get_model(model)
|
||||
|
||||
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
|
||||
if isinstance(m.content, list):
|
||||
for c in m.content:
|
||||
if c.type == "image_url" and c.image_url.url:
|
||||
localize_result = await localize_image_content(c.image_url.url)
|
||||
if localize_result is None:
|
||||
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
|
||||
|
||||
content, format = localize_result
|
||||
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
|
||||
return m
|
||||
|
||||
messages = [await _convert_message(m) for m in messages]
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
|
|
|
@ -180,11 +180,10 @@ def request_has_media(request: ChatCompletionRequest | CompletionRequest):
|
|||
return content_has_media(request.content)
|
||||
|
||||
|
||||
async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
|
||||
image = media.image
|
||||
if image.url and image.url.uri.startswith("http"):
|
||||
async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
|
||||
if uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(image.url.uri)
|
||||
r = await client.get(uri)
|
||||
content = r.content
|
||||
content_type = r.headers.get("content-type")
|
||||
if content_type:
|
||||
|
@ -194,11 +193,7 @@ async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
|
|||
|
||||
return content, format
|
||||
else:
|
||||
# data is a base64 encoded string, decode it to bytes first
|
||||
# TODO(mf): do this more efficiently, decode less
|
||||
data_bytes = base64.b64decode(image.data)
|
||||
pil_image = PIL_Image.open(io.BytesIO(data_bytes))
|
||||
return data_bytes, pil_image.format
|
||||
return None
|
||||
|
||||
|
||||
async def convert_image_content_to_url(
|
||||
|
@ -208,7 +203,18 @@ async def convert_image_content_to_url(
|
|||
if image.url and (not download or image.url.uri.startswith("data")):
|
||||
return image.url.uri
|
||||
|
||||
content, format = await localize_image_content(media)
|
||||
if image.data:
|
||||
# data is a base64 encoded string, decode it to bytes first
|
||||
# TODO(mf): do this more efficiently, decode less
|
||||
content = base64.b64decode(image.data)
|
||||
pil_image = PIL_Image.open(io.BytesIO(content))
|
||||
format = pil_image.format
|
||||
else:
|
||||
localize_result = await localize_image_content(image.url.uri)
|
||||
if localize_result is None:
|
||||
raise ValueError(f"Failed to localize image content from {image.url.uri}")
|
||||
content, format = localize_result
|
||||
|
||||
if include_format:
|
||||
return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
|
||||
else:
|
||||
|
|
|
@ -8,6 +8,17 @@ test_response_basic:
|
|||
- case_id: "saturn"
|
||||
input: "Which planet has rings around it with a name starting with letter S?"
|
||||
output: "saturn"
|
||||
- case_id: "image_input"
|
||||
input:
|
||||
- role: user
|
||||
content:
|
||||
- type: input_text
|
||||
text: "what teams are playing in this image?"
|
||||
- role: user
|
||||
content:
|
||||
- type: input_image
|
||||
image_url: "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg"
|
||||
output: "brooklyn nets"
|
||||
|
||||
test_response_multi_turn:
|
||||
test_name: test_response_multi_turn
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue