Litellm dev 03 08 2025 p3 (#9089)

* feat(ollama_chat.py): pass down http client to ollama_chat

enables easier testing

* fix(factory.py): fix passing images to ollama's `/api/generate` endpoint

Fixes https://github.com/BerriAI/litellm/issues/6683

* fix(factory.py): fix ollama pt to handle templating correctly
Author: Krish Dholakia (committed by GitHub), 2025-03-09 18:20:56 -07:00
Parent: 93273723cd
Commit: e00d4fb18c
5 changed files with 165 additions and 52 deletions
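
As a quick orientation before the diff: the first bullet adds a `client` parameter so callers (and tests) can inject their own HTTP handler. A minimal sketch of how that is exercised, mirroring the new test further down (model name and prompt are illustrative):

# Sketch only: inject an HTTPHandler so the outgoing Ollama request can be intercepted.
# The model name and message are illustrative; the mocked response is empty, so the
# completion call itself is expected to raise and is swallowed here.
from unittest.mock import patch

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
with patch.object(client, "post") as mock_post:
    try:
        litellm.completion(
            model="ollama/llama3.2-vision:11b",
            messages=[{"role": "user", "content": "Describe this image."}],
            client=client,  # forwarded down to the Ollama HTTP call
        )
    except Exception as e:
        print(e)
    mock_post.assert_called()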


@@ -187,53 +187,125 @@ def ollama_pt(
             final_prompt_value="### Response:",
             messages=messages,
         )
-    elif "llava" in model:
-        prompt = ""
-        images = []
-        for message in messages:
-            if isinstance(message["content"], str):
-                prompt += message["content"]
-            elif isinstance(message["content"], list):
-                # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
-                for element in message["content"]:
-                    if isinstance(element, dict):
-                        if element["type"] == "text":
-                            prompt += element["text"]
-                        elif element["type"] == "image_url":
-                            base64_image = convert_to_ollama_image(
-                                element["image_url"]["url"]
-                            )
-                            images.append(base64_image)
-        return {"prompt": prompt, "images": images}
     else:
+        user_message_types = {"user", "tool", "function"}
+        msg_i = 0
+        images = []
         prompt = ""
-        for message in messages:
-            role = message["role"]
-            content = message.get("content", "")
-
-            if "tool_calls" in message:
-                tool_calls = []
-
-                for call in message["tool_calls"]:
-                    call_id: str = call["id"]
-                    function_name: str = call["function"]["name"]
-                    arguments = json.loads(call["function"]["arguments"])
-
-                    tool_calls.append(
-                        {
-                            "id": call_id,
-                            "type": "function",
-                            "function": {"name": function_name, "arguments": arguments},
-                        }
-                    )
-
-                prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
-
-            elif "tool_call_id" in message:
-                prompt += f"### User:\n{message['content']}\n\n"
-
-            elif content:
-                prompt += f"### {role.capitalize()}:\n{content}\n\n"
+        while msg_i < len(messages):
+            init_msg_i = msg_i
+            user_content_str = ""
+            ## MERGE CONSECUTIVE USER CONTENT ##
+            while (
+                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
+            ):
+                msg_content = messages[msg_i].get("content")
+                if msg_content:
+                    if isinstance(msg_content, list):
+                        for m in msg_content:
+                            if m.get("type", "") == "image_url":
+                                if isinstance(m["image_url"], str):
+                                    images.append(m["image_url"])
+                                elif isinstance(m["image_url"], dict):
+                                    images.append(m["image_url"]["url"])
+                            elif m.get("type", "") == "text":
+                                user_content_str += m["text"]
+                    else:
+                        # Tool message content will always be a string
+                        user_content_str += msg_content
+
+                msg_i += 1
+
+            if user_content_str:
+                prompt += f"### User:\n{user_content_str}\n\n"
+
+            assistant_content_str = ""
+            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+                msg_content = messages[msg_i].get("content")
+                if msg_content:
+                    if isinstance(msg_content, list):
+                        for m in msg_content:
+                            if m.get("type", "") == "text":
+                                assistant_content_str += m["text"]
+                    elif isinstance(msg_content, str):
+                        # Tool message content will always be a string
+                        assistant_content_str += msg_content
+
+                tool_calls = messages[msg_i].get("tool_calls")
+                ollama_tool_calls = []
+                if tool_calls:
+                    for call in tool_calls:
+                        call_id: str = call["id"]
+                        function_name: str = call["function"]["name"]
+                        arguments = json.loads(call["function"]["arguments"])
+
+                        ollama_tool_calls.append(
+                            {
+                                "id": call_id,
+                                "type": "function",
+                                "function": {
+                                    "name": function_name,
+                                    "arguments": arguments,
+                                },
+                            }
+                        )
+
+                if ollama_tool_calls:
+                    assistant_content_str += (
+                        f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
+                    )
+
+                msg_i += 1
+
+            if assistant_content_str:
+                prompt += f"### Assistant:\n{assistant_content_str}\n\n"
+
+            if msg_i == init_msg_i:  # prevent infinite loops
+                raise litellm.BadRequestError(
+                    message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
+                    model=model,
+                    llm_provider="ollama",
+                )
+        # prompt = ""
+        # images = []
+        # for message in messages:
+        #     if isinstance(message["content"], str):
+        #         prompt += message["content"]
+        #     elif isinstance(message["content"], list):
+        #         # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
+        #         for element in message["content"]:
+        #             if isinstance(element, dict):
+        #                 if element["type"] == "text":
+        #                     prompt += element["text"]
+        #                 elif element["type"] == "image_url":
+        #                     base64_image = convert_to_ollama_image(
+        #                         element["image_url"]["url"]
+        #                     )
+        #                     images.append(base64_image)
+        #     if "tool_calls" in message:
+        #         tool_calls = []
+        #         for call in message["tool_calls"]:
+        #             call_id: str = call["id"]
+        #             function_name: str = call["function"]["name"]
+        #             arguments = json.loads(call["function"]["arguments"])
+        #             tool_calls.append(
+        #                 {
+        #                     "id": call_id,
+        #                     "type": "function",
+        #                     "function": {"name": function_name, "arguments": arguments},
+        #                 }
+        #             )
+        #         prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
+        #     elif "tool_call_id" in message:
+        #         prompt += f"### User:\n{message['content']}\n\n"
+        return {"prompt": prompt, "images": images}
     return prompt
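
For context, a rough sketch of what the reworked `ollama_pt` is expected to return for a text-plus-image conversation; the import path is an assumption about the repo layout, and the values are illustrative:

# Sketch: expected output shape of ollama_pt after this change.
# NOTE: the import path is an assumption and may differ in the actual repo layout.
from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Whats in this image?"},
            {
                "type": "image_url",
                "image_url": {"url": "https://dummyimage.com/100/100/fff&text=Test+image"},
            },
        ],
    },
    {"role": "assistant", "content": "A placeholder test image."},
]

out = ollama_pt(model="llama3.2-vision", messages=messages)
# out["prompt"] == "### User:\nWhats in this image?\n\n### Assistant:\nA placeholder test image.\n\n"
# out["images"] == ["https://dummyimage.com/100/100/fff&text=Test+image"]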


@@ -1,7 +1,7 @@
 import json
 import time
 import uuid
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

 import aiohttp
 import httpx
@@ -9,7 +9,11 @@ from pydantic import BaseModel
 import litellm
 from litellm import verbose_logger
-from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
@@ -205,6 +209,7 @@ def get_ollama_response(  # noqa: PLR0915
     api_key: Optional[str] = None,
     acompletion: bool = False,
     encoding=None,
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     if api_base.endswith("/api/chat"):
         url = api_base
@@ -301,7 +306,11 @@
     headers: Optional[dict] = None
     if api_key is not None:
         headers = {"Authorization": "Bearer {}".format(api_key)}
-    response = litellm.module_level_client.post(
+
+    sync_client = litellm.module_level_client
+    if client is not None and isinstance(client, HTTPHandler):
+        sync_client = client
+    response = sync_client.post(
         url=url,
         json=data,
         headers=headers,


@@ -2856,6 +2856,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 acompletion=acompletion,
                 model_response=model_response,
                 encoding=encoding,
+                client=client,
             )
             if acompletion is True or optional_params.get("stream", False) is True:
                 return generator


@@ -1,13 +1,4 @@
 model_list:
-  - model_name: openai/gpt-4o
+  - model_name: llama3.2-vision
     litellm_params:
-      model: openai/gpt-4o
-      api_key: os.environ/OPENAI_API_KEY
+      model: ollama/llama3.2-vision
-
-files_settings:
-  - custom_llm_provider: azure
-    api_base: os.environ/AZURE_API_BASE
-    api_key: os.environ/AZURE_API_KEY
-
-general_settings:
-  store_prompts_in_spend_logs: true
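
To smoke-test this config by hand, one option is to start the proxy with it (`litellm --config <path>`) and call its OpenAI-compatible endpoint; the base URL, port, and API key below are placeholders and assume a local Ollama server with the model pulled:

# Sketch: call the proxy using the model_name from the config above.
# http://localhost:4000 and the api_key are placeholders for a locally running proxy.
import openai

client = openai.OpenAI(base_url="http://localhost:4000", api_key="sk-1234")
resp = client.chat.completions.create(
    model="llama3.2-vision",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)
print(resp.choices[0].message.content)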


@@ -1,4 +1,5 @@
 import asyncio
+import json
 import os
 import sys
 import traceback
@@ -76,6 +77,45 @@ def test_ollama_json_mode():
 # test_ollama_json_mode()


+def test_ollama_vision_model():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+    from unittest.mock import patch
+
+    with patch.object(client, "post") as mock_post:
+        try:
+            litellm.completion(
+                model="ollama/llama3.2-vision:11b",
+                messages=[
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": "Whats in this image?"},
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": "https://dummyimage.com/100/100/fff&text=Test+image"
+                                },
+                            },
+                        ],
+                    }
+                ],
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+        mock_post.assert_called()
+
+        print(mock_post.call_args.kwargs)
+
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
+        assert json_data["model"] == "llama3.2-vision:11b"
+        assert "images" in json_data
+        assert "prompt" in json_data
+        assert json_data["prompt"].startswith("### User:\n")
+
+
 mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")