diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py
index 6ce8faa5c6..c8745f5119 100644
--- a/litellm/litellm_core_utils/prompt_templates/common_utils.py
+++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py
@@ -77,6 +77,16 @@ def convert_content_list_to_str(message: AllMessageValues) -> str:
     return texts
 
 
+def get_str_from_messages(messages: List[AllMessageValues]) -> str:
+    """
+    Converts a list of messages to a string
+    """
+    text = ""
+    for message in messages:
+        text += convert_content_list_to_str(message=message)
+    return text
+
+
 def is_non_content_values_set(message: AllMessageValues) -> bool:
     ignore_keys = ["content", "role", "name"]
     return any(
diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py
index ff366b2396..28e09d7ac8 100644
--- a/litellm/litellm_core_utils/prompt_templates/factory.py
+++ b/litellm/litellm_core_utils/prompt_templates/factory.py
@@ -181,7 +181,7 @@ def _handle_ollama_system_message(
 
 
 def ollama_pt(
-    model, messages
+    model: str, messages: list
 ) -> Union[
     str, OllamaVisionModelObject
 ]:  # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py
index 4a7a3556ae..b4db95cfa1 100644
--- a/litellm/llms/ollama/completion/transformation.py
+++ b/litellm/llms/ollama/completion/transformation.py
@@ -6,6 +6,9 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
 from httpx._models import Headers, Response
 
 import litellm
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    get_str_from_messages,
+)
 from litellm.litellm_core_utils.prompt_templates.factory import (
     convert_to_ollama_image,
     custom_prompt,
@@ -302,6 +305,8 @@ class OllamaConfig(BaseConfig):
         custom_prompt_dict = (
             litellm_params.get("custom_prompt_dict") or litellm.custom_prompt_dict
         )
+
+        text_completion_request = litellm_params.get("text_completion")
         if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
@@ -311,7 +316,9 @@
                 final_prompt_value=model_prompt_details["final_prompt_value"],
                 messages=messages,
             )
-        else:
+        elif text_completion_request:  # handle `/completions` requests
+            ollama_prompt = get_str_from_messages(messages=messages)
+        else:  # handle `/chat/completions` requests
             modified_prompt = ollama_pt(model=model, messages=messages)
             if isinstance(modified_prompt, dict):
                 ollama_prompt, images = (
diff --git a/tests/litellm/litellm_core_utils/test_streaming_handler.py b/tests/litellm/litellm_core_utils/test_streaming_handler.py
index 31d541330c..75c4fc1035 100644
--- a/tests/litellm/litellm_core_utils/test_streaming_handler.py
+++ b/tests/litellm/litellm_core_utils/test_streaming_handler.py
@@ -485,17 +485,21 @@ async def test_streaming_with_usage_and_logging(sync_mode: bool):
         time.sleep(1)
         mock_log_success_event.assert_called_once()
         # mock_log_stream_event.assert_called()
+        assert (
+            mock_log_success_event.call_args.kwargs["response_obj"].usage
+            == final_usage_block
+        )
     else:
         await asyncio.sleep(1)
         mock_async_log_success_event.assert_called_once()
         # mock_async_log_stream_event.assert_called()
+        assert (
+            mock_async_log_success_event.call_args.kwargs["response_obj"].usage
+            == final_usage_block
+        )
 
     print(mock_log_success_event.call_args.kwargs.keys())
-    mock_log_success_event.call_args.kwargs[
-        "response_obj"
-    ].usage == final_usage_block
-
 
 
 def test_streaming_handler_with_stop_chunk(
     initialized_custom_stream_wrapper: CustomStreamWrapper,
diff --git a/tests/local_testing/test_text_completion.py b/tests/local_testing/test_text_completion.py
index 11c43de2cc..c2cee53868 100644
--- a/tests/local_testing/test_text_completion.py
+++ b/tests/local_testing/test_text_completion.py
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import os
 import sys
 import traceback
@@ -4285,3 +4286,25 @@ def test_text_completion_with_echo(stream):
             print(chunk)
     else:
         assert isinstance(response, TextCompletionResponse)
+
+
+def test_text_completion_ollama():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_call:
+        try:
+            response = litellm.text_completion(
+                model="ollama/llama3.1:8b",
+                prompt="hello",
+                client=client,
+            )
+            print(response)
+        except Exception as e:
+            print(e)
+
+        mock_call.assert_called_once()
+        print(mock_call.call_args.kwargs)
+        json_data = json.loads(mock_call.call_args.kwargs["data"])
+        assert json_data["prompt"] == "hello"
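
Usage sketch: with this change, an Ollama text-completion call is expected to forward the raw prompt string in the request body's "prompt" field instead of a chat-templated prompt. The snippet below mirrors the new test; the model tag is taken from that test, while the api_base value is an assumed local Ollama default, so adjust both for your setup.

import litellm

# Assumes a locally running Ollama server with the model already pulled;
# both the model tag and api_base are illustrative, not mandated by the patch.
response = litellm.text_completion(
    model="ollama/llama3.1:8b",
    prompt="hello",
    api_base="http://localhost:11434",
)
print(response.choices[0].text)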