Merge pull request #9333 from BerriAI/litellm_dev_03_17_2025_p2

fix(ollama/completions/transformation.py): pass prompt, untemplated o…
Krish Dholakia 2025-03-17 21:48:30 -07:00 committed by GitHub
commit cd5024f3b1
GPG key ID: B5690EEEBB952194
5 changed files with 50 additions and 6 deletions

View file

@@ -77,6 +77,16 @@ def convert_content_list_to_str(message: AllMessageValues) -> str:
     return texts
 
 
+def get_str_from_messages(messages: List[AllMessageValues]) -> str:
+    """
+    Converts a list of messages to a string
+    """
+    text = ""
+    for message in messages:
+        text += convert_content_list_to_str(message=message)
+    return text
+
+
 def is_non_content_values_set(message: AllMessageValues) -> bool:
     ignore_keys = ["content", "role", "name"]
     return any(
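For reference, a minimal sketch (not part of the diff) of what the new helper does: it flattens a chat-style message list into a single prompt string by concatenating each message's content. The message dicts below are illustrative.

# Illustrative only: content may also be a list of content blocks, which
# convert_content_list_to_str() collapses into plain text before concatenation.
messages = [
    {"role": "system", "content": "You are a helpful assistant. "},
    {"role": "user", "content": "hello"},
]

# get_str_from_messages(messages=messages) would return:
# "You are a helpful assistant. hello"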

View file

@@ -181,7 +181,7 @@ def _handle_ollama_system_message(
 def ollama_pt(
-    model, messages
+    model: str, messages: list
 ) -> Union[
     str, OllamaVisionModelObject
 ]:  # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
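The Union return type matters for the caller in the transformation below: ollama_pt returns a plain templated prompt string for text-only messages, or an OllamaVisionModelObject when images are present. A hedged sketch of handling both shapes; the exact key names are assumed from how the caller unpacks the dict.

# Sketch only: "prompt" and "images" key names are assumed, not confirmed here.
result = ollama_pt(model="llama3.1:8b", messages=messages)
if isinstance(result, dict):
    prompt, images = result["prompt"], result["images"]  # vision case (assumed keys)
else:
    prompt = result  # plain templated prompt string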

View file

@@ -6,6 +6,9 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
 from httpx._models import Headers, Response
 
 import litellm
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    get_str_from_messages,
+)
 from litellm.litellm_core_utils.prompt_templates.factory import (
     convert_to_ollama_image,
     custom_prompt,
@@ -302,6 +305,8 @@ class OllamaConfig(BaseConfig):
         custom_prompt_dict = (
             litellm_params.get("custom_prompt_dict") or litellm.custom_prompt_dict
         )
+
+        text_completion_request = litellm_params.get("text_completion")
         if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
@@ -311,7 +316,9 @@ class OllamaConfig(BaseConfig):
                 final_prompt_value=model_prompt_details["final_prompt_value"],
                 messages=messages,
             )
-        else:
+        elif text_completion_request:  # handle `/completions` requests
+            ollama_prompt = get_str_from_messages(messages=messages)
+        else:  # handle `/chat/completions` requests
             modified_prompt = ollama_pt(model=model, messages=messages)
             if isinstance(modified_prompt, dict):
                 ollama_prompt, images = (
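A condensed sketch of the branching introduced here (the helper name is illustrative, not part of the codebase; custom_prompt, get_str_from_messages, and ollama_pt are the functions from the diff): a registered custom prompt wins, a `/completions` request passes the user's text through untemplated, and `/chat/completions` still goes through the Ollama chat template.

# Illustrative standalone helper, not the actual method body.
def _choose_ollama_prompt(model, messages, litellm_params, custom_prompt_dict):
    text_completion_request = litellm_params.get("text_completion")
    if model in custom_prompt_dict:
        details = custom_prompt_dict[model]
        return custom_prompt(
            role_dict=details["roles"],
            initial_prompt_value=details["initial_prompt_value"],
            final_prompt_value=details["final_prompt_value"],
            messages=messages,
        )
    if text_completion_request:
        # /completions: pass the prompt through untemplated
        return get_str_from_messages(messages=messages)
    # /chat/completions: apply the Ollama chat template
    return ollama_pt(model=model, messages=messages)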

View file

@@ -485,17 +485,21 @@ async def test_streaming_with_usage_and_logging(sync_mode: bool):
         time.sleep(1)
         mock_log_success_event.assert_called_once()
         # mock_log_stream_event.assert_called()
+        assert (
+            mock_log_success_event.call_args.kwargs["response_obj"].usage
+            == final_usage_block
+        )
     else:
         await asyncio.sleep(1)
         mock_async_log_success_event.assert_called_once()
         # mock_async_log_stream_event.assert_called()
+        assert (
+            mock_async_log_success_event.call_args.kwargs["response_obj"].usage
+            == final_usage_block
+        )
 
     print(mock_log_success_event.call_args.kwargs.keys())
 
-    mock_log_success_event.call_args.kwargs[
-        "response_obj"
-    ].usage == final_usage_block
-
 
 def test_streaming_handler_with_stop_chunk(
     initialized_custom_stream_wrapper: CustomStreamWrapper,
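The removed lines show the pitfall this hunk fixes: a bare `==` comparison is evaluated and then discarded, so the old check could never fail. A tiny sketch of the difference, with illustrative values:

# The pitfall being fixed: a bare comparison is evaluated and silently ignored,
# so the old test passed even when the usage block was wrong.
actual = {"total_tokens": 0}
expected = {"total_tokens": 42}

actual == expected            # evaluates to False, result thrown away
# assert actual == expected   # what the diff adds instead: fails loudly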

View file

@@ -1,4 +1,5 @@
 import asyncio
+import json
 import os
 import sys
 import traceback
@@ -4285,3 +4286,25 @@ def test_text_completion_with_echo(stream):
             print(chunk)
     else:
         assert isinstance(response, TextCompletionResponse)
+
+
+def test_text_completion_ollama():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_call:
+        try:
+            response = litellm.text_completion(
+                model="ollama/llama3.1:8b",
+                prompt="hello",
+                client=client,
+            )
+            print(response)
+        except Exception as e:
+            print(e)
+
+        mock_call.assert_called_once()
+        print(mock_call.call_args.kwargs)
+        json_data = json.loads(mock_call.call_args.kwargs["data"])
+        assert json_data["prompt"] == "hello"
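Outside of the mocked test, the same call against a locally running Ollama server would look roughly like this; the model name and api_base are illustrative, and the server must already have the model pulled:

import litellm

# Sketch only: assumes a local Ollama server with the model available
# (e.g. via `ollama pull llama3.1:8b`); model name and api_base are illustrative.
response = litellm.text_completion(
    model="ollama/llama3.1:8b",
    prompt="hello",
    api_base="http://localhost:11434",
)
print(response.choices[0].text)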