From 5d7107634736f039575ca21a5ee456b4832dc619 Mon Sep 17 00:00:00 2001
From: linznin
Date: Sat, 23 Nov 2024 21:37:14 +0800
Subject: [PATCH] feat - add multimodal image support for ollama chat
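
User messages in the OpenAI multimodal format are now translated for the
Ollama /api/chat endpoint: "text" parts are concatenated into the message
content, and "image_url" parts are collected into the message's "images"
list as plain base64 strings (any data URI prefix is stripped, and
non-JPEG/PNG images are converted to JPEG via Pillow). Prompt token
counting falls back to encoding only the text parts of each message,
since content may now be a list of parts rather than a string.

A minimal usage sketch (the file name is illustrative and the model name
is taken from the test below; any OpenAI-style image_url content should
work the same way):

    import base64

    import litellm

    # Illustrative local image; any image file works.
    with open("photo.webp", "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")

    response = litellm.completion(
        model="ollama_chat/llama3.2-vision:11b",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/webp;base64,{b64}"},
                    },
                ],
            }
        ],
    )
    print(response.choices[0].message.content)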
---
 litellm/llms/ollama_chat.py       | 56 ++++++++++++++++++-
 litellm/tests/test_ollama_chat.py | 93 +++++++++++++++++++++++++++++++
 2 files changed, 148 insertions(+), 1 deletion(-)
 create mode 100644 litellm/tests/test_ollama_chat.py

diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index ce0df139d0..783553f7d7 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -213,6 +213,44 @@ class OllamaChatConfig:
         return optional_params
 
 
+# Ollama wants plain base64-encoded JPEG/PNG images: strip any leading data URI
+# prefix and convert other formats to JPEG if necessary.
+def _convert_image(image):
+    import base64
+    import io
+
+    try:
+        from PIL import Image
+    except Exception:
+        raise Exception(
+            "ollama image conversion failed; please run `pip install Pillow`"
+        )
+
+    orig = image
+    if image.startswith("data:"):
+        image = image.split(",")[-1]
+    try:
+        image_data = Image.open(io.BytesIO(base64.b64decode(image)))
+        if image_data.format in ["JPEG", "PNG"]:
+            return image
+    except Exception:
+        return orig
+    jpeg_image = io.BytesIO()
+    image_data.convert("RGB").save(jpeg_image, "JPEG")
+    jpeg_image.seek(0)
+    return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")
+
+
+def _get_messages_prompt(message):
+    prompt = ""
+    message_content = message.get("content", None)
+    if message_content and isinstance(message_content, list):
+        for content in message_content:
+            prompt += content.get("text", "")
+    elif message_content and isinstance(message_content, str):
+        prompt = message_content
+    return prompt
+
 # ollama implementation
 def get_ollama_response(  # noqa: PLR0915
     model_response: litellm.ModelResponse,
@@ -249,6 +287,20 @@
             m, BaseModel
         ):  # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
             m = m.model_dump(exclude_none=True)
+        if m["role"] == "user":
+            ## translate user message
+            message_content = m.get("content")
+            if message_content and isinstance(message_content, list):
+                user_text = ""
+                images = []
+                for content in message_content:
+                    if content["type"] == "text":
+                        user_text += content["text"]
+                    elif content["type"] == "image_url":
+                        images.append(content["image_url"]["url"])
+                m["content"] = user_text
+                if images:
+                    m["images"] = [_convert_image(image) for image in images]
         if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
             new_tools: List[OllamaToolCall] = []
             for tool in m["tool_calls"]:
@@ -363,7 +415,9 @@
         model_response.choices[0].message = _message  # type: ignore
     model_response.created = int(time.time())
     model_response.model = "ollama_chat/" + model
-    prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages))  # type: ignore
+    prompt_tokens = response_json.get("prompt_eval_count", len(
+        encoding.encode("".join(_get_messages_prompt(message) for message in messages))
+    ))
     completion_tokens = response_json.get(
         "eval_count", litellm.token_counter(text=response_json["message"]["content"])
     )
diff --git a/litellm/tests/test_ollama_chat.py b/litellm/tests/test_ollama_chat.py
new file mode 100644
index 0000000000..a3b8db9824
--- /dev/null
+++ b/litellm/tests/test_ollama_chat.py
@@ -0,0 +1,93 @@
+import pytest
+
+import litellm
+from unittest.mock import MagicMock, patch
+
+
+def test_ollama_chat_image():
+    """
+    Test that data URI prefixes are removed, JPEG/PNG images are passed
+    through, and other image formats are converted to JPEG. Non-image
+    data is untouched.
+    """
+
+    import base64
+    import io
+
+    from PIL import Image
+
+    def mock_post(url, **kwargs):
+        mock_response = MagicMock()
+        mock_response.status_code = 200
+        mock_response.headers = {"Content-Type": "application/json"}
+        mock_response.json.return_value = {
+            "model": "llama3.2-vision:11b",
+            "created_at": "2024-11-23T13:16:28.5525725Z",
+            "message": {
+                "role": "assistant",
+                "content": "The image is a blank",
+                "images": kwargs["json"]["messages"][0]["images"]
+            },
+            "done_reason": "stop",
+            "done": True,
+            "total_duration": 74458830900,
+            "load_duration": 18295722500,
+            "prompt_eval_count": 17,
+            "prompt_eval_duration": 9979000000,
+            "eval_count": 104,
+            "eval_duration": 46036000000,
+        }
+        return mock_response
+
+    def make_b64image(format):
+        image = Image.new(mode="RGB", size=(1, 1))
+        image_buffer = io.BytesIO()
+        image.save(image_buffer, format)
+        return base64.b64encode(image_buffer.getvalue()).decode("utf-8")
+
+    jpeg_image = make_b64image("JPEG")
+    webp_image = make_b64image("WEBP")
+    png_image = make_b64image("PNG")
+
+    base64_data = base64.b64encode(b"some random data").decode("utf-8")
+    datauri_base64_data = f"data:text/plain;base64,{base64_data}"
+
+    tests = [
+        # input                                   expected
+        [jpeg_image, jpeg_image],
+        [webp_image, None],
+        [png_image, png_image],
+        [f"data:image/jpeg;base64,{jpeg_image}", jpeg_image],
+        [f"data:image/webp;base64,{webp_image}", None],
+        [f"data:image/png;base64,{png_image}", png_image],
+        [datauri_base64_data, datauri_base64_data],
+    ]
+
+    for test in tests:
+        try:
+            with patch("requests.post", side_effect=mock_post):
+                response = litellm.completion(
+                    model="ollama_chat/llava",
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": "What's in this image?"},
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": test[0]},
+                                },
+                            ],
+                        }
+                    ],
+                )
+            if not test[1]:
+                # the conversion process may not always generate the same image,
+                # so just check for a JPEG image when a conversion was done.
+                image_data = response["choices"][0]["message"]["images"][0]
+                image = Image.open(io.BytesIO(base64.b64decode(image_data)))
+                assert image.format == "JPEG"
+            else:
+                assert response["choices"][0]["message"]["images"][0] == test[1]
+        except Exception as e:
+            pytest.fail(f"Error occurred: {e}")