Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 18:24:20 +00:00
fix(factory.py): correct indentation for message index increment in ollama. This fixes bug #9822 (#9943)
* fix(factory.py): correct indentation for message index increment in ollama_pt function
* test: add unit tests for ollama_pt function handling various message types
This commit is contained in:
parent 69a3aab4c8
commit 433075a8d9
2 changed files with 218 additions and 1 deletion
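The hunk below moves a single `msg_i += 1` one indentation level out. Reading the context lines, the old increment only ran inside the branch that appends `Tool Calls: ...` to `assistant_content_str`, so an assistant message without tool calls never advanced the index. A minimal sketch of that loop pattern, with the surrounding `while` structure assumed rather than copied from factory.py, shows why the indentation is the whole bug:

import json

def merge_assistant_messages(messages, msg_i):
    """Simplified sketch of the assistant-merging loop in ollama_pt.

    Assumption: this mirrors the structure implied by the diff hunk below;
    it is not the verbatim litellm implementation.
    """
    assistant_content_str = ""
    while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
        assistant_content_str += messages[msg_i].get("content") or ""

        ollama_tool_calls = messages[msg_i].get("tool_calls") or []
        if ollama_tool_calls:
            assistant_content_str += (
                f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
            )
            # Bug #9822: before the fix, `msg_i += 1` sat here, inside the
            # tool-call branch, so an assistant message without tool calls
            # never advanced the index and the loop could not terminate.

        msg_i += 1  # the fix: advance the index for every assistant message

    return assistant_content_str, msg_i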
@@ -250,7 +250,7 @@ def ollama_pt(
                     f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
                 )

-                msg_i += 1
+            msg_i += 1

         if assistant_content_str:
             prompt += f"### Assistant:\n{assistant_content_str}\n\n"
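With the increment at the `while`-body level, plain assistant turns (no `tool_calls`) also advance the index. A quick illustration of the call shape, mirroring the unit tests added in the second file below (the exact prompt rendering is whatever `ollama_pt` produces):

from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "How can I help you?"},  # no tool_calls
    {"role": "user", "content": "How are you?"},
]

result = ollama_pt(model="llama2", messages=messages)
# ollama_pt returns a dict; per the tests below:
# result["prompt"] == "### User:\nHello\n\n### Assistant:\nHow can I help you?\n\n### User:\nHow are you?\n\n"
# result["images"] == []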
@@ -0,0 +1,217 @@
import pytest
import json
from unittest.mock import patch
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt, BAD_MESSAGE_ERROR_STR

def test_ollama_pt_simple_messages():
    """Test basic functionality with simple text messages"""
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "assistant", "content": "How can I help you?"},
        {"role": "user", "content": "Hello"},
    ]

    result = ollama_pt(model="llama2", messages=messages)

    expected_prompt = "### System:\nYou are a helpful assistant\n\n### Assistant:\nHow can I help you?\n\n### User:\nHello\n\n"
    assert isinstance(result, dict)
    assert result["prompt"] == expected_prompt
    assert result["images"] == []

def test_ollama_pt_consecutive_user_messages():
    """Test handling consecutive user messages"""
    messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "How can I help you?"},
        {"role": "user", "content": "How are you?"},
        {"role": "assistant", "content": "I'm good, thanks!"},
        {"role": "user", "content": "I am well too."},
    ]

    result = ollama_pt(model="llama2", messages=messages)

    # Consecutive user messages should be merged
    expected_prompt = "### User:\nHello\n\n### Assistant:\nHow can I help you?\n\n### User:\nHow are you?\n\n### Assistant:\nI'm good, thanks!\n\n### User:\nI am well too.\n\n"
    assert isinstance(result, dict)
    assert result["prompt"] == expected_prompt

# def test_ollama_pt_consecutive_system_messages():
#     """Test handling consecutive system messages"""
#     messages = [
#         {"role": "user", "content": "Hello"},
#         {"role": "system", "content": "You are a helpful assistant"},
#         {"role": "system", "content": "Be concise and polite"},
#         {"role": "assistant", "content": "How can I help you?"}
#     ]

#     result = ollama_pt(model="llama2", messages=messages)

#     # Consecutive system messages should be merged
#     expected_prompt = "### User:\nHello\n\n### System:\nYou are a helpful assistantBe concise and polite\n\n### Assistant:\nHow can I help you?\n\n"
#     assert result == expected_prompt

# def test_ollama_pt_consecutive_assistant_messages():
#     """Test handling consecutive assistant messages"""
#     messages = [
#         {"role": "user", "content": "Hello"},
#         {"role": "assistant", "content": "Hi there!"},
#         {"role": "assistant", "content": "How can I help you?"},
#         {"role": "user", "content": "Tell me a joke"}
#     ]

#     result = ollama_pt(model="llama2", messages=messages)

#     # Consecutive assistant messages should be merged
#     expected_prompt = "### User:\nHello\n\n### Assistant:\nHi there!How can I help you?\n\n### User:\nTell me a joke\n\n"
#     assert result["prompt"] == expected_prompt

# def test_ollama_pt_with_image_urls_as_strings():
#     """Test handling messages with image URLs as strings"""
#     messages = [
#         {"role": "user", "content": [
#             {"type": "text", "text": "What's in this image?"},
#             {"type": "image_url", "image_url": "http://example.com/image.jpg"}
#         ]},
#         {"role": "assistant", "content": "That's a cat."}
#     ]

#     result = ollama_pt(model="llama2", messages=messages)

#     expected_prompt = "### User:\nWhat's in this image?\n\n### Assistant:\nThat's a cat.\n\n"
#     assert result["prompt"] == expected_prompt
#     assert result["images"] == ["http://example.com/image.jpg"]

# def test_ollama_pt_with_image_urls_as_dicts():
#     """Test handling messages with image URLs as dictionaries"""
#     messages = [
#         {"role": "user", "content": [
#             {"type": "text", "text": "What's in this image?"},
#             {"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}}
#         ]},
#         {"role": "assistant", "content": "That's a cat."}
#     ]

#     result = ollama_pt(model="llama2", messages=messages)

#     expected_prompt = "### User:\nWhat's in this image?\n\n### Assistant:\nThat's a cat.\n\n"
#     assert result["prompt"] == expected_prompt
#     assert result["images"] == ["http://example.com/image.jpg"]

# def test_ollama_pt_with_tool_calls():
#     """Test handling messages with tool calls"""
#     messages = [
#         {"role": "user", "content": "What's the weather in San Francisco?"},
#         {"role": "assistant", "content": "I'll check the weather for you.",
#          "tool_calls": [
#              {
#                  "id": "call_123",
#                  "type": "function",
#                  "function": {
#                      "name": "get_weather",
#                      "arguments": json.dumps({"location": "San Francisco"})
#                  }
#              }
#          ]
#         },
#         {"role": "tool", "content": "Sunny, 72°F"}
#     ]

#     result = ollama_pt(model="llama2", messages=messages)

#     # Check if tool call is included in the prompt
#     assert "### User:\nWhat's the weather in San Francisco?" in result["prompt"]
#     assert "### Assistant:\nI'll check the weather for you.Tool Calls:" in result["prompt"]
#     assert "get_weather" in result["prompt"]
#     assert "San Francisco" in result["prompt"]
#     assert "### User:\nSunny, 72°F\n\n" in result["prompt"]

# def test_ollama_pt_error_handling():
#     """Test error handling for invalid messages"""
#     messages = [
#         {"role": "invalid_role", "content": "This is an invalid role"}
#     ]

#     with pytest.raises(litellm.BadRequestError) as excinfo:
#         ollama_pt(model="llama2", messages=messages)

#     assert BAD_MESSAGE_ERROR_STR in str(excinfo.value)

# def test_ollama_pt_empty_messages():
#     """Test with empty messages list"""
#     messages = []

#     result = ollama_pt(model="llama2", messages=messages)

#     assert result["prompt"] == ""
#     assert result["images"] == []

# def test_ollama_pt_with_tool_message_content():
#     """Test handling tool message content"""
#     messages = [
#         {"role": "user", "content": "Tell me a joke"},
#         {"role": "assistant", "content": "Why did the chicken cross the road?"},
#         {"role": "user", "content": "Why?"},
#         {"role": "assistant", "content": "To get to the other side!"},
#         {"role": "tool", "content": "Joke rating: 5/10"}
#     ]

#     result = ollama_pt(model="llama2", messages=messages)

#     assert "### User:\nTell me a joke" in result["prompt"]
#     assert "### Assistant:\nWhy did the chicken cross the road?" in result["prompt"]
#     assert "### User:\nWhy?" in result["prompt"]
#     assert "### Assistant:\nTo get to the other side!" in result["prompt"]
#     assert "### User:\nJoke rating: 5/10\n\n" in result["prompt"]

# def test_ollama_pt_with_function_message():
#     """Test handling function messages (treated as user message type)"""
#     messages = [
#         {"role": "user", "content": "What's 2+2?"},
#         {"role": "function", "content": "The result is 4"},
#         {"role": "assistant", "content": "The answer is 4."}
#     ]

#     result = ollama_pt(model="llama2", messages=messages)

#     assert "### User:\nWhat's 2+2?The result is 4\n\n" in result["prompt"]
#     assert "### Assistant:\nThe answer is 4.\n\n" in result["prompt"]

# def test_ollama_pt_with_multiple_images():
#     """Test handling multiple images in a message"""
#     messages = [
#         {"role": "user", "content": [
#             {"type": "text", "text": "Compare these images:"},
#             {"type": "image_url", "image_url": "http://example.com/image1.jpg"},
#             {"type": "image_url", "image_url": "http://example.com/image2.jpg"}
#         ]},
#         {"role": "assistant", "content": "Both images show cats, but different breeds."}
#     ]

#     result = ollama_pt(model="llama2", messages=messages)

#     expected_prompt = "### User:\nCompare these images:\n\n### Assistant:\nBoth images show cats, but different breeds.\n\n"
#     assert result["prompt"] == expected_prompt
#     assert result["images"] == ["http://example.com/image1.jpg", "http://example.com/image2.jpg"]

# def test_ollama_pt_mixed_content_types():
#     """Test handling a mix of string and list content types"""
#     messages = [
#         {"role": "user", "content": "Hello"},
#         {"role": "assistant", "content": "Hi there!"},
#         {"role": "user", "content": [
#             {"type": "text", "text": "Look at this:"},
#             {"type": "image_url", "image_url": "http://example.com/image.jpg"}
#         ]},
#         {"role": "system", "content": "Be helpful"},
#         {"role": "assistant", "content": "I see a cat in the image."}
#     ]

#     result = ollama_pt(model="llama2", messages=messages)

#     assert "### User:\nHello\n\n" in result["prompt"]
#     assert "### Assistant:\nHi there!\n\n" in result["prompt"]
#     assert "### User:\nLook at this:\n\n" in result["prompt"]
#     assert "### System:\nBe helpful\n\n" in result["prompt"]
#     assert "### Assistant:\nI see a cat in the image.\n\n" in result["prompt"]
#     assert result["images"] == ["http://example.com/image.jpg"]