mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 10:14:26 +00:00
feat - add multimodal support for ollama chat
This commit is contained in:
parent
7e9d8b58f6
commit
5d71076347
2 changed files with 148 additions and 1 deletions
|
@ -213,6 +213,44 @@ class OllamaChatConfig:
|
|||
return optional_params
|
||||
|
||||
|
||||
# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
# and convert to jpeg if necessary.
def _convert_image(image):
    """Normalize an image payload for the ollama API.

    Strips a leading ``data:`` URI prefix, passes base64 JPEG/PNG data
    through untouched, re-encodes any other image format to JPEG, and
    returns input that is not a decodable image completely unchanged.
    """
    import base64
    import io

    try:
        from PIL import Image
    except Exception:
        raise Exception(
            "ollama image conversion failed please run `pip install Pillow`"
        )

    original_payload = image
    # Keep only the base64 portion of a data URI, if one is present.
    if image.startswith("data:"):
        image = image.rpartition(",")[2]

    try:
        decoded = Image.open(io.BytesIO(base64.b64decode(image)))
    except Exception:
        # Not decodable as an image: hand the caller's input back untouched.
        return original_payload

    # Formats ollama accepts directly need no re-encoding.
    if decoded.format in ["JPEG", "PNG"]:
        return image

    # Anything else (webp, gif, ...) gets re-encoded as JPEG.
    jpeg_buffer = io.BytesIO()
    decoded.convert("RGB").save(jpeg_buffer, "JPEG")
    jpeg_buffer.seek(0)
    return base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def _get_messages_prompt(message):
|
||||
prompt = ""
|
||||
message_content = message.get("content", None)
|
||||
if message_content and isinstance(message_content, list):
|
||||
for content in message_content:
|
||||
prompt += content.get("text", "")
|
||||
elif message_content and isinstance(message_content, str):
|
||||
prompt = message_content
|
||||
return prompt
|
||||
|
||||
# ollama implementation
|
||||
def get_ollama_response( # noqa: PLR0915
|
||||
model_response: litellm.ModelResponse,
|
||||
|
@ -249,6 +287,20 @@ def get_ollama_response( # noqa: PLR0915
|
|||
m, BaseModel
|
||||
): # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
|
||||
m = m.model_dump(exclude_none=True)
|
||||
if m["role"] == "user":
|
||||
## translate user message
|
||||
message_content = m.get("content")
|
||||
if message_content and isinstance(message_content, list):
|
||||
user_text = ""
|
||||
images = []
|
||||
for content in message_content:
|
||||
if content["type"] == "text":
|
||||
user_text += content["text"]
|
||||
elif content["type"] == "image_url":
|
||||
images.append(content["image_url"]["url"])
|
||||
m["content"] = user_text
|
||||
if images:
|
||||
m["images"] = [_convert_image(image) for image in images]
|
||||
if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
|
||||
new_tools: List[OllamaToolCall] = []
|
||||
for tool in m["tool_calls"]:
|
||||
|
@ -363,7 +415,9 @@ def get_ollama_response( # noqa: PLR0915
|
|||
model_response.choices[0].message = _message # type: ignore
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = "ollama_chat/" + model
|
||||
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
|
||||
prompt_tokens = response_json.get("prompt_eval_count", len(
|
||||
encoding.encode("".join(_get_messages_prompt(prompt) for prompt in messages))
|
||||
))
|
||||
completion_tokens = response_json.get(
|
||||
"eval_count", litellm.token_counter(text=response_json["message"]["content"])
|
||||
)
|
||||
|
|
93
litellm/tests/test_ollama_chat.py
Normal file
93
litellm/tests/test_ollama_chat.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
import pytest
|
||||
|
||||
import litellm
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_ollama_chat_image():
    """
    Test that datauri prefixes are removed, JPEG/PNG images are passed
    through, and other image formats are converted to JPEG. Non-image
    data is untouched.
    """

    import base64
    import io

    from PIL import Image

    def mock_post(url, **kwargs):
        # Echo the images the client sent back in the response so the
        # assertions below can inspect what conversion produced.
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json.return_value = {
            "model": "llama3.2-vision:11b",
            "created_at": "2024-11-23T13:16:28.5525725Z",
            "message": {
                "role": "assistant",
                "content": "The image is a blank",
                "images": kwargs["json"]["messages"][0]["images"]
            },
            "done_reason": "stop",
            "done": True,
            "total_duration": 74458830900,
            "load_duration": 18295722500,
            "prompt_eval_count": 17,
            "prompt_eval_duration": 9979000000,
            "eval_count": 104,
            "eval_duration": 46036000000,
        }
        return mock_response

    def make_b64image(format):
        # A 1x1 RGB image in the requested format, as base64 text.
        image = Image.new(mode="RGB", size=(1, 1))
        image_buffer = io.BytesIO()
        image.save(image_buffer, format)
        return base64.b64encode(image_buffer.getvalue()).decode("utf-8")

    jpeg_image = make_b64image("JPEG")
    webp_image = make_b64image("WEBP")
    png_image = make_b64image("PNG")

    # decode() so the data URI carries the base64 text itself; interpolating
    # the raw bytes object would embed its repr (b'...'), which is not
    # valid base64 payload.
    base64_data = base64.b64encode(b"some random data").decode("utf-8")
    datauri_base64_data = f"data:text/plain;base64,{base64_data}"

    tests = [
        # input              expected (None => expect a JPEG conversion)
        [jpeg_image, jpeg_image],
        [webp_image, None],
        [png_image, png_image],
        [f"data:image/jpeg;base64,{jpeg_image}", jpeg_image],
        [f"data:image/webp;base64,{webp_image}", None],
        [f"data:image/png;base64,{png_image}", png_image],
        [datauri_base64_data, datauri_base64_data],
    ]

    for test in tests:
        with patch("requests.post", side_effect=mock_post):
            response = litellm.completion(
                model="ollama_chat/llava",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "Whats in this image?"},
                            {
                                "type": "image_url",
                                "image_url": {"url": test[0]},
                            },
                        ],
                    }
                ],
            )
            if not test[1]:
                # the conversion process may not always generate the same image,
                # so just check for a JPEG image when a conversion was done.
                image_data = response["choices"][0]["message"]["images"][0]
                image = Image.open(io.BytesIO(base64.b64decode(image_data)))
                assert image.format == "JPEG"
            else:
                assert response["choices"][0]["message"]["images"][0] == test[1]
|
Loading…
Add table
Add a link
Reference in a new issue