Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
Merge pull request #9333 from BerriAI/litellm_dev_03_17_2025_p2

fix(ollama/completions/transformation.py): pass prompt, untemplated o…

Commit cd5024f3b1

5 changed files with 50 additions and 6 deletions

In short: Ollama `/completions` (text completion) requests now send the user's prompt untemplated, while `/chat/completions` requests keep the existing chat templating.
@@ -77,6 +77,16 @@ def convert_content_list_to_str(message: AllMessageValues) -> str:
     return texts
 
 
+def get_str_from_messages(messages: List[AllMessageValues]) -> str:
+    """
+    Converts a list of messages to a string
+    """
+    text = ""
+    for message in messages:
+        text += convert_content_list_to_str(message=message)
+    return text
+
+
 def is_non_content_values_set(message: AllMessageValues) -> bool:
     ignore_keys = ["content", "role", "name"]
     return any(
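The new helper (added to prompt_templates/common_utils.py, judging by the import introduced later in this diff) simply concatenates the text content of every message. A minimal standalone sketch of its behavior, using plain dicts in place of litellm's AllMessageValues type:

# Standalone sketch; plain dicts stand in for litellm's AllMessageValues.
from typing import List, Union


def convert_content_list_to_str(message: dict) -> str:
    # "content" may be a plain string or a list of {"type": "text", ...} parts.
    content: Union[str, list, None] = message.get("content")
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    return "".join(
        part.get("text", "") for part in content if isinstance(part, dict)
    )


def get_str_from_messages(messages: List[dict]) -> str:
    text = ""
    for message in messages:
        text += convert_content_list_to_str(message=message)
    return text


# [{"role": "user", "content": "hello"}] -> "hello"
print(get_str_from_messages([{"role": "user", "content": "hello"}]))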
@@ -181,7 +181,7 @@ def _handle_ollama_system_message(
 
 
 def ollama_pt(
-    model, messages
+    model: str, messages: list
 ) -> Union[
     str, OllamaVisionModelObject
 ]:  # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
@@ -6,6 +6,9 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
 from httpx._models import Headers, Response
 
 import litellm
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    get_str_from_messages,
+)
 from litellm.litellm_core_utils.prompt_templates.factory import (
     convert_to_ollama_image,
     custom_prompt,
@@ -302,6 +305,8 @@ class OllamaConfig(BaseConfig):
         custom_prompt_dict = (
             litellm_params.get("custom_prompt_dict") or litellm.custom_prompt_dict
         )
+
+        text_completion_request = litellm_params.get("text_completion")
         if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
@@ -311,7 +316,9 @@ class OllamaConfig(BaseConfig):
                 final_prompt_value=model_prompt_details["final_prompt_value"],
                 messages=messages,
             )
-        else:
+        elif text_completion_request:  # handle `/completions` requests
+            ollama_prompt = get_str_from_messages(messages=messages)
+        else:  # handle `/chat/completions` requests
             modified_prompt = ollama_pt(model=model, messages=messages)
             if isinstance(modified_prompt, dict):
                 ollama_prompt, images = (
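This branch is the heart of the fix: when litellm_params carries the text_completion flag, the raw message text from get_str_from_messages becomes the Ollama prompt, and the chat templating in ollama_pt only runs for /chat/completions traffic. A self-contained sketch of the decision order (the two helper bodies below are illustrative stand-ins, not litellm's real custom_prompt() or ollama_pt() implementations):

# Illustrative stand-ins for litellm's custom_prompt() and ollama_pt().
def _apply_custom_prompt(template: dict, messages: list) -> str:
    body = "".join(m.get("content", "") for m in messages)
    return (
        template.get("initial_prompt_value", "")
        + body
        + template.get("final_prompt_value", "")
    )


def _chat_template(model: str, messages: list) -> str:
    # Crude stand-in for the real per-model chat template.
    return "".join(f"### {m['role']}:\n{m['content']}\n" for m in messages)


def build_ollama_prompt(
    model: str,
    messages: list,
    custom_prompt_dict: dict,
    text_completion_request: bool,
) -> str:
    if model in custom_prompt_dict:
        # A user-registered custom prompt still takes precedence.
        return _apply_custom_prompt(custom_prompt_dict[model], messages)
    if text_completion_request:
        # /completions: forward the user's text untemplated.
        return "".join(m.get("content", "") for m in messages)
    # /chat/completions: apply the chat template (ollama_pt in litellm).
    return _chat_template(model, messages)


# /completions with prompt="hello" now yields "hello", not a templated chat turn.
print(build_ollama_prompt("llama3.1:8b", [{"role": "user", "content": "hello"}], {}, True))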
@@ -485,17 +485,21 @@ async def test_streaming_with_usage_and_logging(sync_mode: bool):
             time.sleep(1)
             mock_log_success_event.assert_called_once()
             # mock_log_stream_event.assert_called()
+            assert (
+                mock_log_success_event.call_args.kwargs["response_obj"].usage
+                == final_usage_block
+            )
         else:
             await asyncio.sleep(1)
             mock_async_log_success_event.assert_called_once()
             # mock_async_log_stream_event.assert_called()
+            assert (
+                mock_async_log_success_event.call_args.kwargs["response_obj"].usage
+                == final_usage_block
+            )
 
         print(mock_log_success_event.call_args.kwargs.keys())
 
-        mock_log_success_event.call_args.kwargs[
-            "response_obj"
-        ].usage == final_usage_block
-
 
 def test_streaming_handler_with_stop_chunk(
     initialized_custom_stream_wrapper: CustomStreamWrapper,
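The removed lines were a bare == comparison whose result was discarded, so they could never fail; the new assert blocks perform the same check properly, inside whichever branch owns the relevant mock. The difference in miniature:

# A bare comparison is a no-op statement; only assert makes it a check.
usage, final_usage_block = 10, 20
usage == final_usage_block          # evaluates to False, result discarded
assert usage == final_usage_block   # raises AssertionError on mismatch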
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import os
 import sys
 import traceback
@@ -4285,3 +4286,25 @@ def test_text_completion_with_echo(stream):
             print(chunk)
     else:
         assert isinstance(response, TextCompletionResponse)
+
+
+def test_text_completion_ollama():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_call:
+        try:
+            response = litellm.text_completion(
+                model="ollama/llama3.1:8b",
+                prompt="hello",
+                client=client,
+            )
+            print(response)
+        except Exception as e:
+            print(e)
+
+        mock_call.assert_called_once()
+        print(mock_call.call_args.kwargs)
+        json_data = json.loads(mock_call.call_args.kwargs["data"])
+        assert json_data["prompt"] == "hello"
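The new test pins the fixed behavior down end to end: HTTPHandler.post is mocked, so no Ollama server is needed, and the final assertion checks that the JSON body sent to Ollama carries the prompt verbatim rather than a chat-templated string. For reference, the same call against a live local Ollama instance would look like this (a sketch assuming the default endpoint and a pulled llama3.1:8b model):

# Assumes a local Ollama server on the default port with llama3.1:8b pulled.
import litellm

response = litellm.text_completion(
    model="ollama/llama3.1:8b",
    prompt="hello",
)
print(response.choices[0].text)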