mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
fix(ollama): Download remote image URLs for Ollama
This commit is contained in:
parent
6fa5271807
commit
a17894a7b0
3 changed files with 44 additions and 10 deletions
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import base64
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator, AsyncIterator
|
from collections.abc import AsyncGenerator, AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
@ -77,6 +78,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
content_has_media,
|
content_has_media,
|
||||||
convert_image_content_to_url,
|
convert_image_content_to_url,
|
||||||
interleaved_content_as_str,
|
interleaved_content_as_str,
|
||||||
|
localize_image_content,
|
||||||
request_has_media,
|
request_has_media,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -496,6 +498,21 @@ class OllamaInferenceAdapter(
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||||
model_obj = await self._get_model(model)
|
model_obj = await self._get_model(model)
|
||||||
|
|
||||||
|
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||||
|
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
|
||||||
|
if isinstance(m.content, list):
|
||||||
|
for c in m.content:
|
||||||
|
if c.type == "image_url" and c.image_url.url:
|
||||||
|
localize_result = await localize_image_content(c.image_url.url)
|
||||||
|
if localize_result is None:
|
||||||
|
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
|
||||||
|
|
||||||
|
content, format = localize_result
|
||||||
|
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
|
||||||
|
return m
|
||||||
|
|
||||||
|
messages = [await _convert_message(m) for m in messages]
|
||||||
params = await prepare_openai_completion_params(
|
params = await prepare_openai_completion_params(
|
||||||
model=model_obj.provider_resource_id,
|
model=model_obj.provider_resource_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
@ -180,11 +180,10 @@ def request_has_media(request: ChatCompletionRequest | CompletionRequest):
|
||||||
return content_has_media(request.content)
|
return content_has_media(request.content)
|
||||||
|
|
||||||
|
|
||||||
async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
|
async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
|
||||||
image = media.image
|
if uri.startswith("http"):
|
||||||
if image.url and image.url.uri.startswith("http"):
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.get(image.url.uri)
|
r = await client.get(uri)
|
||||||
content = r.content
|
content = r.content
|
||||||
content_type = r.headers.get("content-type")
|
content_type = r.headers.get("content-type")
|
||||||
if content_type:
|
if content_type:
|
||||||
|
@ -194,11 +193,7 @@ async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
|
||||||
|
|
||||||
return content, format
|
return content, format
|
||||||
else:
|
else:
|
||||||
# data is a base64 encoded string, decode it to bytes first
|
return None
|
||||||
# TODO(mf): do this more efficiently, decode less
|
|
||||||
data_bytes = base64.b64decode(image.data)
|
|
||||||
pil_image = PIL_Image.open(io.BytesIO(data_bytes))
|
|
||||||
return data_bytes, pil_image.format
|
|
||||||
|
|
||||||
|
|
||||||
async def convert_image_content_to_url(
|
async def convert_image_content_to_url(
|
||||||
|
@ -208,7 +203,18 @@ async def convert_image_content_to_url(
|
||||||
if image.url and (not download or image.url.uri.startswith("data")):
|
if image.url and (not download or image.url.uri.startswith("data")):
|
||||||
return image.url.uri
|
return image.url.uri
|
||||||
|
|
||||||
content, format = await localize_image_content(media)
|
if image.data:
|
||||||
|
# data is a base64 encoded string, decode it to bytes first
|
||||||
|
# TODO(mf): do this more efficiently, decode less
|
||||||
|
content = base64.b64decode(image.data)
|
||||||
|
pil_image = PIL_Image.open(io.BytesIO(content))
|
||||||
|
format = pil_image.format
|
||||||
|
else:
|
||||||
|
localize_result = await localize_image_content(image.url.uri)
|
||||||
|
if localize_result is None:
|
||||||
|
raise ValueError(f"Failed to localize image content from {image.url.uri}")
|
||||||
|
content, format = localize_result
|
||||||
|
|
||||||
if include_format:
|
if include_format:
|
||||||
return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
|
return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -8,6 +8,17 @@ test_response_basic:
|
||||||
- case_id: "saturn"
|
- case_id: "saturn"
|
||||||
input: "Which planet has rings around it with a name starting with letter S?"
|
input: "Which planet has rings around it with a name starting with letter S?"
|
||||||
output: "saturn"
|
output: "saturn"
|
||||||
|
- case_id: "image_input"
|
||||||
|
input:
|
||||||
|
- role: user
|
||||||
|
content:
|
||||||
|
- type: input_text
|
||||||
|
text: "what teams are playing in this image?"
|
||||||
|
- role: user
|
||||||
|
content:
|
||||||
|
- type: input_image
|
||||||
|
image_url: "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg"
|
||||||
|
output: "brooklyn nets"
|
||||||
|
|
||||||
test_response_multi_turn:
|
test_response_multi_turn:
|
||||||
test_name: test_response_multi_turn
|
test_name: test_response_multi_turn
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue