Litellm dev 03 08 2025 p3 (#9089)

* feat(ollama_chat.py): pass down http client to ollama_chat

enables easier testing

* fix(factory.py): fix passing images to ollama's `/api/generate` endpoint

Fixes https://github.com/BerriAI/litellm/issues/6683

* fix(factory.py): fix ollama pt to handle templating correctly
Krish Dholakia 2025-03-09 18:20:56 -07:00 committed by GitHub
parent 93273723cd
commit e00d4fb18c
5 changed files with 165 additions and 52 deletions


@@ -187,53 +187,125 @@ def ollama_pt(
             final_prompt_value="### Response:",
             messages=messages,
         )
-    elif "llava" in model:
-        prompt = ""
-        images = []
-        for message in messages:
-            if isinstance(message["content"], str):
-                prompt += message["content"]
-            elif isinstance(message["content"], list):
-                # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
-                for element in message["content"]:
-                    if isinstance(element, dict):
-                        if element["type"] == "text":
-                            prompt += element["text"]
-                        elif element["type"] == "image_url":
-                            base64_image = convert_to_ollama_image(
-                                element["image_url"]["url"]
-                            )
-                            images.append(base64_image)
-        return {"prompt": prompt, "images": images}
     else:
+        user_message_types = {"user", "tool", "function"}
+        msg_i = 0
+        images = []
         prompt = ""
-        for message in messages:
-            role = message["role"]
-            content = message.get("content", "")
-
-            if "tool_calls" in message:
-                tool_calls = []
-
-                for call in message["tool_calls"]:
-                    call_id: str = call["id"]
-                    function_name: str = call["function"]["name"]
-                    arguments = json.loads(call["function"]["arguments"])
-
-                    tool_calls.append(
-                        {
-                            "id": call_id,
-                            "type": "function",
-                            "function": {"name": function_name, "arguments": arguments},
-                        }
-                    )
-
-                prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
-
-            elif "tool_call_id" in message:
-                prompt += f"### User:\n{message['content']}\n\n"
-
-            elif content:
-                prompt += f"### {role.capitalize()}:\n{content}\n\n"
-
-    return prompt
+        while msg_i < len(messages):
+            init_msg_i = msg_i
+            user_content_str = ""
+            ## MERGE CONSECUTIVE USER CONTENT ##
+            while (
+                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
+            ):
+                msg_content = messages[msg_i].get("content")
+                if msg_content:
+                    if isinstance(msg_content, list):
+                        for m in msg_content:
+                            if m.get("type", "") == "image_url":
+                                if isinstance(m["image_url"], str):
+                                    images.append(m["image_url"])
+                                elif isinstance(m["image_url"], dict):
+                                    images.append(m["image_url"]["url"])
+                            elif m.get("type", "") == "text":
+                                user_content_str += m["text"]
+                    else:
+                        # Tool message content will always be a string
+                        user_content_str += msg_content
+
+                msg_i += 1
+
+            if user_content_str:
+                prompt += f"### User:\n{user_content_str}\n\n"
+
+            assistant_content_str = ""
+            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+                msg_content = messages[msg_i].get("content")
+                if msg_content:
+                    if isinstance(msg_content, list):
+                        for m in msg_content:
+                            if m.get("type", "") == "text":
+                                assistant_content_str += m["text"]
+                    elif isinstance(msg_content, str):
+                        # Tool message content will always be a string
+                        assistant_content_str += msg_content
+
+                tool_calls = messages[msg_i].get("tool_calls")
+                ollama_tool_calls = []
+                if tool_calls:
+                    for call in tool_calls:
+                        call_id: str = call["id"]
+                        function_name: str = call["function"]["name"]
+                        arguments = json.loads(call["function"]["arguments"])
+
+                        ollama_tool_calls.append(
+                            {
+                                "id": call_id,
+                                "type": "function",
+                                "function": {
+                                    "name": function_name,
+                                    "arguments": arguments,
+                                },
+                            }
+                        )
+
+                if ollama_tool_calls:
+                    assistant_content_str += (
+                        f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
+                    )
+
+                msg_i += 1
+
+            if assistant_content_str:
+                prompt += f"### Assistant:\n{assistant_content_str}\n\n"
+
+            if msg_i == init_msg_i:  # prevent infinite loops
+                raise litellm.BadRequestError(
+                    message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
+                    model=model,
+                    llm_provider="ollama",
+                )
+
+    # prompt = ""
+    # images = []
+    # for message in messages:
+    #     if isinstance(message["content"], str):
+    #         prompt += message["content"]
+    #     elif isinstance(message["content"], list):
+    #         # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
+    #         for element in message["content"]:
+    #             if isinstance(element, dict):
+    #                 if element["type"] == "text":
+    #                     prompt += element["text"]
+    #                 elif element["type"] == "image_url":
+    #                     base64_image = convert_to_ollama_image(
+    #                         element["image_url"]["url"]
+    #                     )
+    #                     images.append(base64_image)
+    # if "tool_calls" in message:
+    #     tool_calls = []
+    #     for call in message["tool_calls"]:
+    #         call_id: str = call["id"]
+    #         function_name: str = call["function"]["name"]
+    #         arguments = json.loads(call["function"]["arguments"])
+    #         tool_calls.append(
+    #             {
+    #                 "id": call_id,
+    #                 "type": "function",
+    #                 "function": {"name": function_name, "arguments": arguments},
+    #             }
+    #         )
+    #     prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
+    # elif "tool_call_id" in message:
+    #     prompt += f"### User:\n{message['content']}\n\n"
+    return {"prompt": prompt, "images": images}
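For context on the templating fix: the rewritten ollama_pt merges consecutive user/tool messages into a single "### User:" block, collects image URLs into a separate images list, and turns assistant turns (including tool calls) into "### Assistant:" blocks. A minimal sketch of the resulting payload shape, using hypothetical message content:

# Illustrative only -- the messages and URL below are made up; the expected values
# mirror the "### User:" / "### Assistant:" templating in the new code above.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    },
    {"role": "assistant", "content": "A cat sitting on a couch."},
]

expected = {
    "prompt": "### User:\nWhat's in this image?\n\n### Assistant:\nA cat sitting on a couch.\n\n",
    "images": ["https://example.com/cat.png"],
}

The images list is what ends up in the images field of the /api/generate request, which is what https://github.com/BerriAI/litellm/issues/6683 was about.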


@@ -1,7 +1,7 @@
 import json
 import time
 import uuid
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

 import aiohttp
 import httpx
@@ -9,7 +9,11 @@ from pydantic import BaseModel
 import litellm
 from litellm import verbose_logger
-from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
@@ -205,6 +209,7 @@ def get_ollama_response(  # noqa: PLR0915
     api_key: Optional[str] = None,
     acompletion: bool = False,
     encoding=None,
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     if api_base.endswith("/api/chat"):
         url = api_base
@@ -301,7 +306,11 @@ def get_ollama_response(  # noqa: PLR0915
     headers: Optional[dict] = None
     if api_key is not None:
         headers = {"Authorization": "Bearer {}".format(api_key)}
-    response = litellm.module_level_client.post(
+    sync_client = litellm.module_level_client
+    if client is not None and isinstance(client, HTTPHandler):
+        sync_client = client
+    response = sync_client.post(
         url=url,
         json=data,
         headers=headers,


@@ -2856,6 +2856,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 acompletion=acompletion,
                 model_response=model_response,
                 encoding=encoding,
+                client=client,
             )
             if acompletion is True or optional_params.get("stream", False) is True:
                 return generator
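From the caller's side, the change above means a custom HTTP client passed to litellm.completion() is now forwarded to the Ollama handler. A rough usage sketch (assumes a local Ollama server with the model pulled; the handler configuration and model name are illustrative, not part of this commit):

# Sketch only: pass your own HTTPHandler (e.g. pre-configured, or mocked in tests)
# and the Ollama POST goes through it instead of litellm.module_level_client.
import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()  # could also be a handler with custom timeouts
response = litellm.completion(
    model="ollama/llama3.2-vision",
    messages=[{"role": "user", "content": "Describe the sky in one sentence."}],
    client=client,
)
print(response.choices[0].message.content)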


@@ -1,13 +1,4 @@
 model_list:
-  - model_name: openai/gpt-4o
+  - model_name: llama3.2-vision
     litellm_params:
-      model: openai/gpt-4o
-      api_key: os.environ/OPENAI_API_KEY
-
-files_settings:
-  - custom_llm_provider: azure
-    api_base: os.environ/AZURE_API_BASE
-    api_key: os.environ/AZURE_API_KEY
-
-general_settings:
-  store_prompts_in_spend_logs: true
+      model: ollama/llama3.2-vision
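With the updated proxy config, requests for llama3.2-vision are routed to the local Ollama deployment. A hypothetical client-side call, assuming the proxy is started with this config and listens on the default http://localhost:4000 (the port and api_key value are assumptions, not part of this commit):

# Hypothetical call against a LiteLLM proxy started with the config above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-anything")
resp = client.chat.completions.create(
    model="llama3.2-vision",  # matches model_name in the new config
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://dummyimage.com/100/100/fff&text=Test+image"},
                },
            ],
        }
    ],
)
print(resp.choices[0].message.content)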


@@ -1,4 +1,5 @@
 import asyncio
+import json
 import os
 import sys
 import traceback
@@ -76,6 +77,45 @@ def test_ollama_json_mode():
 # test_ollama_json_mode()


+def test_ollama_vision_model():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    from unittest.mock import patch
+
+    with patch.object(client, "post") as mock_post:
+        try:
+            litellm.completion(
+                model="ollama/llama3.2-vision:11b",
+                messages=[
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": "Whats in this image?"},
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": "https://dummyimage.com/100/100/fff&text=Test+image"
+                                },
+                            },
+                        ],
+                    }
+                ],
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+        mock_post.assert_called()
+
+        print(mock_post.call_args.kwargs)
+
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
+        assert json_data["model"] == "llama3.2-vision:11b"
+        assert "images" in json_data
+        assert "prompt" in json_data
+
+        assert json_data["prompt"].startswith("### User:\n")
+
+
 mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")