Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
Litellm dev 03 08 2025 p3 (#9089)
* feat(ollama_chat.py): pass down http client to ollama_chat - enables easier testing
* fix(factory.py): fix passing images to ollama's `/api/generate` endpoint. Fixes https://github.com/BerriAI/litellm/issues/6683
* fix(factory.py): fix ollama pt to handle templating correctly
This commit is contained in:
parent 93273723cd
commit e00d4fb18c

5 changed files with 165 additions and 52 deletions
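For context, a minimal usage sketch of the first change: passing your own HTTP client through litellm.completion so the Ollama request can be intercepted, for example in tests. The model name and the locally running Ollama server are assumptions for illustration, not part of this commit.

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

# Hypothetical local setup: an Ollama server with a vision model already pulled.
client = HTTPHandler()

response = litellm.completion(
    model="ollama/llama3.2-vision",
    messages=[{"role": "user", "content": "Hello!"}],
    client=client,  # forwarded down to the Ollama handler by this change
)
print(response.choices[0].message.content)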
@@ -187,53 +187,125 @@ def ollama_pt(
             final_prompt_value="### Response:",
             messages=messages,
         )
-    elif "llava" in model:
-        prompt = ""
-        images = []
-        for message in messages:
-            if isinstance(message["content"], str):
-                prompt += message["content"]
-            elif isinstance(message["content"], list):
-                # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
-                for element in message["content"]:
-                    if isinstance(element, dict):
-                        if element["type"] == "text":
-                            prompt += element["text"]
-                        elif element["type"] == "image_url":
-                            base64_image = convert_to_ollama_image(
-                                element["image_url"]["url"]
-                            )
-                            images.append(base64_image)
-        return {"prompt": prompt, "images": images}
     else:
+        user_message_types = {"user", "tool", "function"}
+        msg_i = 0
+        images = []
         prompt = ""
-        for message in messages:
-            role = message["role"]
-            content = message.get("content", "")
-
-            if "tool_calls" in message:
-                tool_calls = []
-
-                for call in message["tool_calls"]:
-                    call_id: str = call["id"]
-                    function_name: str = call["function"]["name"]
-                    arguments = json.loads(call["function"]["arguments"])
-
-                    tool_calls.append(
-                        {
-                            "id": call_id,
-                            "type": "function",
-                            "function": {"name": function_name, "arguments": arguments},
-                        }
-                    )
-
-                prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
-
-            elif "tool_call_id" in message:
-                prompt += f"### User:\n{message['content']}\n\n"
-
-            elif content:
-                prompt += f"### {role.capitalize()}:\n{content}\n\n"
+        while msg_i < len(messages):
+            init_msg_i = msg_i
+            user_content_str = ""
+            ## MERGE CONSECUTIVE USER CONTENT ##
+            while (
+                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
+            ):
+                msg_content = messages[msg_i].get("content")
+                if msg_content:
+                    if isinstance(msg_content, list):
+                        for m in msg_content:
+                            if m.get("type", "") == "image_url":
+                                if isinstance(m["image_url"], str):
+                                    images.append(m["image_url"])
+                                elif isinstance(m["image_url"], dict):
+                                    images.append(m["image_url"]["url"])
+                            elif m.get("type", "") == "text":
+                                user_content_str += m["text"]
+                    else:
+                        # Tool message content will always be a string
+                        user_content_str += msg_content
+
+                msg_i += 1
+
+            if user_content_str:
+                prompt += f"### User:\n{user_content_str}\n\n"
+
+            assistant_content_str = ""
+            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+                msg_content = messages[msg_i].get("content")
+                if msg_content:
+                    if isinstance(msg_content, list):
+                        for m in msg_content:
+                            if m.get("type", "") == "text":
+                                assistant_content_str += m["text"]
+                    elif isinstance(msg_content, str):
+                        # Tool message content will always be a string
+                        assistant_content_str += msg_content
+
+                tool_calls = messages[msg_i].get("tool_calls")
+                ollama_tool_calls = []
+                if tool_calls:
+                    for call in tool_calls:
+                        call_id: str = call["id"]
+                        function_name: str = call["function"]["name"]
+                        arguments = json.loads(call["function"]["arguments"])
+
+                        ollama_tool_calls.append(
+                            {
+                                "id": call_id,
+                                "type": "function",
+                                "function": {
+                                    "name": function_name,
+                                    "arguments": arguments,
+                                },
+                            }
+                        )
+
+                if ollama_tool_calls:
+                    assistant_content_str += (
+                        f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
+                    )
+
+                msg_i += 1
+
+            if assistant_content_str:
+                prompt += f"### Assistant:\n{assistant_content_str}\n\n"
+
+            if msg_i == init_msg_i:  # prevent infinite loops
+                raise litellm.BadRequestError(
+                    message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
+                    model=model,
+                    llm_provider="ollama",
+                )
+
+        # prompt = ""
+        # images = []
+        # for message in messages:
+        #     if isinstance(message["content"], str):
+        #         prompt += message["content"]
+        #     elif isinstance(message["content"], list):
+        #         # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
+        #         for element in message["content"]:
+        #             if isinstance(element, dict):
+        #                 if element["type"] == "text":
+        #                     prompt += element["text"]
+        #                 elif element["type"] == "image_url":
+        #                     base64_image = convert_to_ollama_image(
+        #                         element["image_url"]["url"]
+        #                     )
+        #                     images.append(base64_image)
+
+        #     if "tool_calls" in message:
+        #         tool_calls = []
+
+        #         for call in message["tool_calls"]:
+        #             call_id: str = call["id"]
+        #             function_name: str = call["function"]["name"]
+        #             arguments = json.loads(call["function"]["arguments"])
+
+        #             tool_calls.append(
+        #                 {
+        #                     "id": call_id,
+        #                     "type": "function",
+        #                     "function": {"name": function_name, "arguments": arguments},
+        #                 }
+        #             )
+
+        #         prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
+
+        #     elif "tool_call_id" in message:
+        #         prompt += f"### User:\n{message['content']}\n\n"
+
+        return {"prompt": prompt, "images": images}
     return prompt
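A hedged sketch of what the reworked ollama_pt templating produces for the non-instruct branch above. The import path and keyword arguments are assumptions inferred from this diff, not confirmed by it; the point is that consecutive user turns are merged into a single "### User:" block and image URLs are collected into a separate list.

from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt

messages = [
    {"role": "user", "content": "Hi there."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Whats in this image?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ],
    },
    {"role": "assistant", "content": "A cat."},
]

result = ollama_pt(model="llama3.2-vision", messages=messages)
# Expected shape (illustrative):
# {
#     "prompt": "### User:\nHi there.Whats in this image?\n\n### Assistant:\nA cat.\n\n",
#     "images": ["data:image/png;base64,..."],
# }
print(result["prompt"])
print(result["images"])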
@@ -1,7 +1,7 @@
 import json
 import time
 import uuid
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

 import aiohttp
 import httpx
@@ -9,7 +9,11 @@ from pydantic import BaseModel

 import litellm
 from litellm import verbose_logger
-from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
@@ -205,6 +209,7 @@ def get_ollama_response(  # noqa: PLR0915
     api_key: Optional[str] = None,
     acompletion: bool = False,
     encoding=None,
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     if api_base.endswith("/api/chat"):
         url = api_base
@@ -301,7 +306,11 @@ def get_ollama_response(  # noqa: PLR0915
     headers: Optional[dict] = None
     if api_key is not None:
         headers = {"Authorization": "Bearer {}".format(api_key)}
-    response = litellm.module_level_client.post(
+
+    sync_client = litellm.module_level_client
+    if client is not None and isinstance(client, HTTPHandler):
+        sync_client = client
+    response = sync_client.post(
         url=url,
         json=data,
         headers=headers,
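The hunk above is the core of the client pass-through. Below is a sketch of the same pattern in isolation, using only names that appear in this diff; post_to_ollama itself is a hypothetical simplified wrapper, not a litellm function.

from typing import Optional, Union

import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler


def post_to_ollama(
    url: str,
    data: dict,
    headers: Optional[dict] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
    # Fall back to litellm's shared module-level client unless a sync
    # HTTPHandler was injected (e.g. a mocked one in a unit test).
    sync_client = litellm.module_level_client
    if client is not None and isinstance(client, HTTPHandler):
        sync_client = client
    return sync_client.post(url=url, json=data, headers=headers)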
@@ -2856,6 +2856,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 acompletion=acompletion,
                 model_response=model_response,
                 encoding=encoding,
+                client=client,
             )
             if acompletion is True or optional_params.get("stream", False) is True:
                 return generator
@@ -1,13 +1,4 @@
 model_list:
-  - model_name: openai/gpt-4o
+  - model_name: llama3.2-vision
     litellm_params:
-      model: openai/gpt-4o
-      api_key: os.environ/OPENAI_API_KEY
-
-files_settings:
-  - custom_llm_provider: azure
-    api_base: os.environ/AZURE_API_BASE
-    api_key: os.environ/AZURE_API_KEY
-
-general_settings:
-  store_prompts_in_spend_logs: true
+      model: ollama/llama3.2-vision
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import os
 import sys
 import traceback
@@ -76,6 +77,45 @@ def test_ollama_json_mode():
 # test_ollama_json_mode()


+def test_ollama_vision_model():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+    from unittest.mock import patch
+
+    with patch.object(client, "post") as mock_post:
+        try:
+            litellm.completion(
+                model="ollama/llama3.2-vision:11b",
+                messages=[
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": "Whats in this image?"},
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": "https://dummyimage.com/100/100/fff&text=Test+image"
+                                },
+                            },
+                        ],
+                    }
+                ],
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+        mock_post.assert_called()
+
+        print(mock_post.call_args.kwargs)
+
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
+        assert json_data["model"] == "llama3.2-vision:11b"
+        assert "images" in json_data
+        assert "prompt" in json_data
+        assert json_data["prompt"].startswith("### User:\n")
+
+
 mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")