mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
fix(main.py): fix openai message for assistant msg if role is missing - openai allows this
Fixes https://github.com/BerriAI/litellm/issues/8661
This commit is contained in:
parent
de3989dbc5
commit
ac6e503461
3 changed files with 143 additions and 5 deletions
|
@ -94,7 +94,7 @@ from litellm.utils import (
|
|||
read_config_args,
|
||||
supports_httpx_timeout,
|
||||
token_counter,
|
||||
validate_chat_completion_messages,
|
||||
validate_and_fix_openai_messages,
|
||||
validate_chat_completion_tool_choice,
|
||||
)
|
||||
|
||||
|
@ -851,7 +851,7 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
if model is None:
|
||||
raise ValueError("model param not passed in.")
|
||||
# validate messages
|
||||
messages = validate_chat_completion_messages(messages=messages)
|
||||
messages = validate_and_fix_openai_messages(messages=messages)
|
||||
# validate tool_choice
|
||||
tool_choice = validate_chat_completion_tool_choice(tool_choice=tool_choice)
|
||||
######### unpacking kwargs #####################
|
||||
|
|
|
@ -5932,6 +5932,18 @@ def convert_to_dict(message: Union[BaseModel, dict]) -> dict:
|
|||
)
|
||||
|
||||
|
||||
def validate_and_fix_openai_messages(messages: List):
|
||||
"""
|
||||
Ensures all messages are valid OpenAI chat completion messages.
|
||||
|
||||
Handles missing role for assistant messages.
|
||||
"""
|
||||
for message in messages:
|
||||
if "role" not in message:
|
||||
message["role"] = "assistant"
|
||||
return validate_chat_completion_messages(messages=messages)
|
||||
|
||||
|
||||
def validate_chat_completion_messages(messages: List[AllMessageValues]):
|
||||
"""
|
||||
Ensures all messages are valid OpenAI chat completion messages.
|
||||
|
@ -6282,11 +6294,18 @@ def get_end_user_id_for_cost_tracking(
|
|||
return None
|
||||
return end_user_id
|
||||
|
||||
def should_use_cohere_v1_client(api_base: Optional[str], present_version_params: List[str]):
|
||||
|
||||
def should_use_cohere_v1_client(
|
||||
api_base: Optional[str], present_version_params: List[str]
|
||||
):
|
||||
if not api_base:
|
||||
return False
|
||||
uses_v1_params = ("max_chunks_per_doc" in present_version_params) and ('max_tokens_per_doc' not in present_version_params)
|
||||
return api_base.endswith("/v1/rerank") or (uses_v1_params and not api_base.endswith("/v2/rerank"))
|
||||
uses_v1_params = ("max_chunks_per_doc" in present_version_params) and (
|
||||
"max_tokens_per_doc" not in present_version_params
|
||||
)
|
||||
return api_base.endswith("/v1/rerank") or (
|
||||
uses_v1_params and not api_base.endswith("/v2/rerank")
|
||||
)
|
||||
|
||||
|
||||
def is_prompt_caching_valid_prompt(
|
||||
|
|
119
tests/litellm/test_main.py
Normal file
119
tests/litellm/test_main.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_api_response():
|
||||
mock_response_data = {
|
||||
"id": "chatcmpl-B0W3vmiM78Xkgx7kI7dr7PC949DMS",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"message": {
|
||||
"content": "",
|
||||
"refusal": None,
|
||||
"role": "assistant",
|
||||
"audio": None,
|
||||
"function_call": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
}
|
||||
],
|
||||
"created": 1739462947,
|
||||
"model": "gpt-4o-mini-2024-07-18",
|
||||
"object": "chat.completion",
|
||||
"service_tier": "default",
|
||||
"system_fingerprint": "fp_bd83329f63",
|
||||
"usage": {
|
||||
"completion_tokens": 1,
|
||||
"prompt_tokens": 121,
|
||||
"total_tokens": 122,
|
||||
"completion_tokens_details": {
|
||||
"accepted_prediction_tokens": 0,
|
||||
"audio_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
"rejected_prediction_tokens": 0,
|
||||
},
|
||||
"prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
|
||||
},
|
||||
}
|
||||
|
||||
return mock_response_data
|
||||
|
||||
|
||||
def test_completion_missing_role(openai_api_response):
|
||||
from openai import OpenAI
|
||||
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
client = OpenAI(api_key="test_api_key")
|
||||
|
||||
mock_raw_response = MagicMock()
|
||||
mock_raw_response.headers = {
|
||||
"x-request-id": "123",
|
||||
"openai-organization": "org-123",
|
||||
"x-ratelimit-limit-requests": "100",
|
||||
"x-ratelimit-remaining-requests": "99",
|
||||
}
|
||||
mock_raw_response.parse.return_value = ModelResponse(**openai_api_response)
|
||||
|
||||
print(f"openai_api_response: {openai_api_response}")
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create", mock_raw_response
|
||||
) as mock_create:
|
||||
litellm.completion(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{"role": "user", "content": "Hey"},
|
||||
{
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_m0vFJjQmTH1McvaHBPR2YFwY",
|
||||
"function": {
|
||||
"arguments": '{"input": "dksjsdkjdhskdjshdskhjkhlk"}',
|
||||
"name": "tool_name",
|
||||
},
|
||||
"type": "function",
|
||||
"index": 0,
|
||||
},
|
||||
{
|
||||
"id": "call_Vw6RaqV2n5aaANXEdp5pYxo2",
|
||||
"function": {
|
||||
"arguments": '{"input": "jkljlkjlkjlkjlk"}',
|
||||
"name": "tool_name",
|
||||
},
|
||||
"type": "function",
|
||||
"index": 1,
|
||||
},
|
||||
{
|
||||
"id": "call_hBIKwldUEGlNh6NlSXil62K4",
|
||||
"function": {
|
||||
"arguments": '{"input": "jkjlkjlkjlkj;lj"}',
|
||||
"name": "tool_name",
|
||||
},
|
||||
"type": "function",
|
||||
"index": 2,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
client=client,
|
||||
)
|
||||
|
||||
mock_create.assert_called_once()
|
Loading…
Add table
Add a link
Reference in a new issue