This commit is contained in:
linznin 2025-04-24 01:03:19 -07:00 committed by GitHub
commit a631b72e05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 148 additions and 1 deletions

View file

@ -198,6 +198,44 @@ class OllamaChatConfig(OpenAIGPTConfig):
return optional_params
# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
# and convert to jpeg if necessary.
def _convert_image(image):
    """Normalize a base64 / data-URI image for ollama.

    Strips any leading ``data:`` URI header, passes JPEG and PNG payloads
    through untouched, re-encodes every other decodable image format to
    base64 JPEG, and returns the caller's original value unchanged when the
    payload cannot be decoded as an image at all.
    """
    import base64
    import io

    try:
        from PIL import Image
    except Exception:
        raise Exception(
            "ollama image conversion failed please run `pip install Pillow`"
        )

    raw = image
    # Drop a "data:<mime>;base64," prefix if one is present.
    if image.startswith("data:"):
        image = image.split(",")[-1]

    try:
        decoded = Image.open(io.BytesIO(base64.b64decode(image)))
        if decoded.format in ["JPEG", "PNG"]:
            # Already in a format ollama accepts — hand back the stripped base64.
            return image
    except Exception:
        # Not decodable image data; return the input untouched (best effort).
        return raw

    # Any other format (WEBP, GIF, ...) gets re-encoded as base64 JPEG.
    buffer = io.BytesIO()
    decoded.convert("RGB").save(buffer, "JPEG")
    buffer.seek(0)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def _get_messages_prompt(message):
prompt = ""
message_content = message.get("content", None)
if message_content and isinstance(message_content, list):
for content in message_content:
prompt += content.get("text", "")
elif message_content and isinstance(message_content, str):
prompt = message_content
return prompt
# ollama implementation
def get_ollama_response( # noqa: PLR0915
model_response: ModelResponse,
@ -236,6 +274,20 @@ def get_ollama_response( # noqa: PLR0915
m, BaseModel
): # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
m = m.model_dump(exclude_none=True)
if m["role"] == "user":
## translate user message
message_content = m.get("content")
if message_content and isinstance(message_content, list):
user_text = ""
images = []
for content in message_content:
if content["type"] == "text":
user_text += content["text"]
elif content["type"] == "image_url":
images.append(content["image_url"]["url"])
m["content"] = user_text
if images:
m["images"] = [_convert_image(image) for image in images]
if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
new_tools: List[OllamaToolCall] = []
for tool in m["tool_calls"]:
@ -357,7 +409,9 @@ def get_ollama_response( # noqa: PLR0915
model_response.choices[0].message = _message # type: ignore
model_response.created = int(time.time())
model_response.model = "ollama_chat/" + model
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
prompt_tokens = response_json.get("prompt_eval_count", len(
encoding.encode("".join(_get_messages_prompt(prompt) for prompt in messages))
))
completion_tokens = response_json.get(
"eval_count", litellm.token_counter(text=response_json["message"]["content"])
)

View file

@ -0,0 +1,93 @@
import pytest
import litellm
from unittest.mock import MagicMock, patch
def test_ollama_chat_image():
    """
    Test that datauri prefixes are removed, JPEG/PNG images are passed
    through, and other image formats are converted to JPEG. Non-image
    data is untouched.
    """
    import base64
    import io

    from PIL import Image

    def mock_post(url, **kwargs):
        # Echo the images sent to the API back in the mocked response so the
        # assertions below can inspect what _convert_image produced.
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json.return_value = {
            "model": "llama3.2-vision:11b",
            "created_at": "2024-11-23T13:16:28.5525725Z",
            "message": {
                "role": "assistant",
                "content": "The image is a blank",
                "images": kwargs["json"]["messages"][0]["images"],
            },
            "done_reason": "stop",
            "done": True,
            "total_duration": 74458830900,
            "load_duration": 18295722500,
            "prompt_eval_count": 17,
            "prompt_eval_duration": 9979000000,
            "eval_count": 104,
            "eval_duration": 46036000000,
        }
        return mock_response

    def make_b64image(format):
        # Build a 1x1 RGB image encoded as base64 in the requested format.
        image = Image.new(mode="RGB", size=(1, 1))
        image_buffer = io.BytesIO()
        image.save(image_buffer, format)
        return base64.b64encode(image_buffer.getvalue()).decode("utf-8")

    jpeg_image = make_b64image("JPEG")
    webp_image = make_b64image("WEBP")
    png_image = make_b64image("PNG")

    # Decode to str so the data URI embeds the base64 text itself; embedding
    # the bytes object would produce "data:text/plain;base64,b'...'", which is
    # not valid base64 and would not exercise the non-image-data path.
    base64_data = base64.b64encode(b"some random data").decode("utf-8")
    datauri_base64_data = f"data:text/plain;base64,{base64_data}"

    tests = [
        # input                                     expected
        [jpeg_image, jpeg_image],
        [webp_image, None],  # None => expect a converted JPEG
        [png_image, png_image],
        [f"data:image/jpeg;base64,{jpeg_image}", jpeg_image],
        [f"data:image/webp;base64,{webp_image}", None],
        [f"data:image/png;base64,{png_image}", png_image],
        [datauri_base64_data, datauri_base64_data],
    ]

    for test in tests:
        try:
            with patch("requests.post", side_effect=mock_post):
                response = litellm.completion(
                    model="ollama_chat/llava",
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": "Whats in this image?"},
                                {
                                    "type": "image_url",
                                    "image_url": {"url": test[0]},
                                },
                            ],
                        }
                    ],
                )

            if not test[1]:
                # the conversion process may not always generate the same image,
                # so just check for a JPEG image when a conversion was done.
                image_data = response["choices"][0]["message"]["images"][0]
                image = Image.open(io.BytesIO(base64.b64decode(image_data)))
                assert image.format == "JPEG"
            else:
                assert response["choices"][0]["message"]["images"][0] == test[1]
        except Exception as e:
            pytest.fail(f"Error occurred: {e}")