diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py
index 615df2cb89..df7aa2cbd0 100644
--- a/litellm/litellm_core_utils/prompt_templates/factory.py
+++ b/litellm/litellm_core_utils/prompt_templates/factory.py
@@ -187,53 +187,86 @@ def ollama_pt(
             final_prompt_value="### Response:",
             messages=messages,
         )
-    elif "llava" in model:
-        prompt = ""
-        images = []
-        for message in messages:
-            if isinstance(message["content"], str):
-                prompt += message["content"]
-            elif isinstance(message["content"], list):
-                # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
-                for element in message["content"]:
-                    if isinstance(element, dict):
-                        if element["type"] == "text":
-                            prompt += element["text"]
-                        elif element["type"] == "image_url":
-                            base64_image = convert_to_ollama_image(
-                                element["image_url"]["url"]
-                            )
-                            images.append(base64_image)
-        return {"prompt": prompt, "images": images}
     else:
+        user_message_types = {"user", "tool", "function"}
+        msg_i = 0
+        images = []
         prompt = ""
-        for message in messages:
-            role = message["role"]
-            content = message.get("content", "")
+        while msg_i < len(messages):
+            init_msg_i = msg_i
+            user_content_str = ""
+            ## MERGE CONSECUTIVE USER CONTENT ##
+            while (
+                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
+            ):
+                msg_content = messages[msg_i].get("content")
+                if msg_content:
+                    if isinstance(msg_content, list):
+                        for m in msg_content:
+                            if m.get("type", "") == "image_url":
+                                if isinstance(m["image_url"], str):
+                                    images.append(m["image_url"])
+                                elif isinstance(m["image_url"], dict):
+                                    images.append(m["image_url"]["url"])
+                            elif m.get("type", "") == "text":
+                                user_content_str += m["text"]
+                    else:
+                        # Tool message content will always be a string
+                        user_content_str += msg_content

-            if "tool_calls" in message:
-                tool_calls = []
+                msg_i += 1

-                for call in message["tool_calls"]:
-                    call_id: str = call["id"]
-                    function_name: str = call["function"]["name"]
-                    arguments = json.loads(call["function"]["arguments"])
+            if user_content_str:
+                prompt += f"### User:\n{user_content_str}\n\n"

-                    tool_calls.append(
-                        {
-                            "id": call_id,
-                            "type": "function",
-                            "function": {"name": function_name, "arguments": arguments},
-                        }
+            assistant_content_str = ""
+            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+                msg_content = messages[msg_i].get("content")
+                if msg_content:
+                    if isinstance(msg_content, list):
+                        for m in msg_content:
+                            if m.get("type", "") == "text":
+                                assistant_content_str += m["text"]
+                    elif isinstance(msg_content, str):
+                        # Assistant message content can also be a plain string
+                        assistant_content_str += msg_content
+
+                tool_calls = messages[msg_i].get("tool_calls")
+                ollama_tool_calls = []
+                if tool_calls:
+                    for call in tool_calls:
+                        call_id: str = call["id"]
+                        function_name: str = call["function"]["name"]
+                        arguments = json.loads(call["function"]["arguments"])
+
+                        ollama_tool_calls.append(
+                            {
+                                "id": call_id,
+                                "type": "function",
+                                "function": {
+                                    "name": function_name,
+                                    "arguments": arguments,
+                                },
+                            }
+                        )
+
+                if ollama_tool_calls:
+                    assistant_content_str += (
+                        f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
                     )

-                prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
+                msg_i += 1

-            elif "tool_call_id" in message:
-                prompt += f"### User:\n{message['content']}\n\n"
+            if assistant_content_str:
+                prompt += f"### Assistant:\n{assistant_content_str}\n\n"

-            elif content:
-                prompt += f"### {role.capitalize()}:\n{content}\n\n"
+            if msg_i == init_msg_i:  # prevent infinite loops
+                raise litellm.BadRequestError(
+                    message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
+                    model=model,
+                    llm_provider="ollama",
+                )
+
+        return {"prompt": prompt, "images": images}

     return prompt
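The reworked `ollama_pt` branch above merges consecutive `user`/`tool`/`function` messages into a single `### User:` block, merges consecutive `assistant` messages (serializing any tool calls into the text), collects image URLs into `images`, and raises `BadRequestError` when a message advances neither loop instead of spinning forever. Note that image URLs are now forwarded as-is; the old llava-only `convert_to_ollama_image` call is gone. Below is a minimal sketch of the resulting behavior; the message payload and image data are illustrative placeholders, not taken from this PR:

```python
# Sketch of the merged-message path in the reworked ollama_pt.
# The payload below is made up for illustration.
from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt

result = ollama_pt(
    model="llama3.2-vision",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this picture."},
                # dict-style image_url: the nested "url" value is collected
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0..."}},
            ],
        },
        {"role": "assistant", "content": "It appears to be a test image."},
        {"role": "user", "content": "Any text in it?"},
    ],
)

# The merged path now returns a dict rather than a bare prompt string:
# result["prompt"] == (
#     "### User:\nDescribe this picture.\n\n"
#     "### Assistant:\nIt appears to be a test image.\n\n"
#     "### User:\nAny text in it?\n\n"
# )
# result["images"] == ["data:image/png;base64,iVBORw0..."]
```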
diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index aea1f38951..6f421680b4 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -1,7 +1,7 @@
 import json
 import time
 import uuid
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

 import aiohttp
 import httpx
@@ -9,7 +9,11 @@ from pydantic import BaseModel

 import litellm
 from litellm import verbose_logger
-from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
@@ -205,6 +209,7 @@ def get_ollama_response(  # noqa: PLR0915
     api_key: Optional[str] = None,
     acompletion: bool = False,
     encoding=None,
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     if api_base.endswith("/api/chat"):
         url = api_base
@@ -301,7 +306,11 @@ def get_ollama_response(  # noqa: PLR0915
     headers: Optional[dict] = None
     if api_key is not None:
         headers = {"Authorization": "Bearer {}".format(api_key)}
-    response = litellm.module_level_client.post(
+
+    sync_client = litellm.module_level_client
+    if client is not None and isinstance(client, HTTPHandler):
+        sync_client = client
+    response = sync_client.post(
         url=url,
         json=data,
         headers=headers,
diff --git a/litellm/main.py b/litellm/main.py
index 1699e79cf7..5aa5653cca 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2856,6 +2856,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 acompletion=acompletion,
                 model_response=model_response,
                 encoding=encoding,
+                client=client,
             )
             if acompletion is True or optional_params.get("stream", False) is True:
                 return generator
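With `client` now threaded from `completion()` down into `get_ollama_response`, callers can inject their own synchronous `HTTPHandler` for the Ollama request (connection reuse, custom timeouts, or mocking in tests, which is exactly what the new test below relies on). A hedged usage sketch; the `api_base` value assumes Ollama's default local port and the model tag is illustrative:

```python
# Sketch: supplying your own HTTPHandler to an ollama completion call.
import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()  # reused for every request made through this call

response = litellm.completion(
    model="ollama/llama3.2-vision",
    messages=[{"role": "user", "content": "Hello!"}],
    api_base="http://localhost:11434",  # default Ollama port; adjust to your host
    client=client,  # forwarded by completion() into get_ollama_response
)
print(response.choices[0].message.content)
```

Async callers are unaffected: the `isinstance(client, HTTPHandler)` guard means an `AsyncHTTPHandler` (or no client at all) falls back to `litellm.module_level_client` on the sync path.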
 model_list:
-  - model_name: openai/gpt-4o
+  - model_name: llama3.2-vision
     litellm_params:
-      model: openai/gpt-4o
-      api_key: os.environ/OPENAI_API_KEY
-
-files_settings:
-  - custom_llm_provider: azure
-    api_base: os.environ/AZURE_API_BASE
-    api_key: os.environ/AZURE_API_KEY
-
-general_settings:
-  store_prompts_in_spend_logs: true
\ No newline at end of file
+      model: ollama/llama3.2-vision
\ No newline at end of file
diff --git a/tests/local_testing/test_ollama.py b/tests/local_testing/test_ollama.py
index 81cd331263..09c50315e0 100644
--- a/tests/local_testing/test_ollama.py
+++ b/tests/local_testing/test_ollama.py
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import os
 import sys
 import traceback
@@ -76,6 +77,45 @@ def test_ollama_json_mode():
 # test_ollama_json_mode()


+def test_ollama_vision_model():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+    from unittest.mock import patch
+
+    with patch.object(client, "post") as mock_post:
+        try:
+            litellm.completion(
+                model="ollama/llama3.2-vision:11b",
+                messages=[
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": "What's in this image?"},
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": "https://dummyimage.com/100/100/fff&text=Test+image"
+                                },
+                            },
+                        ],
+                    }
+                ],
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+        mock_post.assert_called()
+
+        print(mock_post.call_args.kwargs)
+
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
+        assert json_data["model"] == "llama3.2-vision:11b"
+        assert "images" in json_data
+        assert "prompt" in json_data
+        assert json_data["prompt"].startswith("### User:\n")
+
+
 mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
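The trimmed proxy config registers `llama3.2-vision` backed by `ollama/llama3.2-vision`, and the test asserts that the flattened `prompt`/`images` payload produced by the new `ollama_pt` path reaches the mocked `post`. For completeness, a sketch of exercising that config through the proxy with the OpenAI SDK; the port and `api_key` value are assumptions based on the proxy's usual local defaults (`litellm --config litellm/proxy/_new_secret_config.yaml`):

```python
# Sketch: calling the llama3.2-vision proxy entry via the OpenAI SDK.
# Assumes the litellm proxy is running locally on its default port 4000.
import openai

client = openai.OpenAI(api_key="sk-anything", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="llama3.2-vision",  # model_name from the proxy config above
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://dummyimage.com/100/100/fff&text=Test+image"},
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)
```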