From ac6e503461b038d85ef1cfa29ae15e29d6b189c1 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 24 Feb 2025 10:43:31 -0800
Subject: [PATCH] fix(main.py): fix openai message for assistant msg if role is missing - openai allows this

Fixes https://github.com/BerriAI/litellm/issues/8661
---
 litellm/main.py            |   4 +-
 litellm/utils.py           |  25 +++++++-
 tests/litellm/test_main.py | 119 +++++++++++++++++++++++++++++++++++++
 3 files changed, 143 insertions(+), 5 deletions(-)
 create mode 100644 tests/litellm/test_main.py

diff --git a/litellm/main.py b/litellm/main.py
index ece484f1f2..c52bfd7c92 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -94,7 +94,7 @@ from litellm.utils import (
     read_config_args,
     supports_httpx_timeout,
     token_counter,
-    validate_chat_completion_messages,
+    validate_and_fix_openai_messages,
     validate_chat_completion_tool_choice,
 )
 
@@ -851,7 +851,7 @@ def completion(  # type: ignore # noqa: PLR0915
     if model is None:
         raise ValueError("model param not passed in.")
     # validate messages
-    messages = validate_chat_completion_messages(messages=messages)
+    messages = validate_and_fix_openai_messages(messages=messages)
     # validate tool_choice
     tool_choice = validate_chat_completion_tool_choice(tool_choice=tool_choice)
     ######### unpacking kwargs #####################
diff --git a/litellm/utils.py b/litellm/utils.py
index facc2ac59b..4113361b69 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5932,6 +5932,18 @@ def convert_to_dict(message: Union[BaseModel, dict]) -> dict:
     )
 
 
+def validate_and_fix_openai_messages(messages: List):
+    """
+    Ensures all messages are valid OpenAI chat completion messages.
+
+    Handles missing role for assistant messages.
+    """
+    for message in messages:
+        if "role" not in message:
+            message["role"] = "assistant"
+    return validate_chat_completion_messages(messages=messages)
+
+
 def validate_chat_completion_messages(messages: List[AllMessageValues]):
     """
     Ensures all messages are valid OpenAI chat completion messages.
@@ -6282,11 +6294,18 @@ def get_end_user_id_for_cost_tracking(
         return None
     return end_user_id
 
-def should_use_cohere_v1_client(api_base: Optional[str], present_version_params: List[str]):
+
+def should_use_cohere_v1_client(
+    api_base: Optional[str], present_version_params: List[str]
+):
     if not api_base:
         return False
-    uses_v1_params = ("max_chunks_per_doc" in present_version_params) and ('max_tokens_per_doc' not in present_version_params)
-    return api_base.endswith("/v1/rerank") or (uses_v1_params and not api_base.endswith("/v2/rerank"))
+    uses_v1_params = ("max_chunks_per_doc" in present_version_params) and (
+        "max_tokens_per_doc" not in present_version_params
+    )
+    return api_base.endswith("/v1/rerank") or (
+        uses_v1_params and not api_base.endswith("/v2/rerank")
+    )
 
 
 def is_prompt_caching_valid_prompt(
diff --git a/tests/litellm/test_main.py b/tests/litellm/test_main.py
new file mode 100644
index 0000000000..b838434915
--- /dev/null
+++ b/tests/litellm/test_main.py
@@ -0,0 +1,119 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+from unittest.mock import MagicMock, patch
+
+import litellm
+
+
+@pytest.fixture
+def openai_api_response():
+    mock_response_data = {
+        "id": "chatcmpl-B0W3vmiM78Xkgx7kI7dr7PC949DMS",
+        "choices": [
+            {
+                "finish_reason": "stop",
+                "index": 0,
+                "logprobs": None,
+                "message": {
+                    "content": "",
+                    "refusal": None,
+                    "role": "assistant",
+                    "audio": None,
+                    "function_call": None,
+                    "tool_calls": None,
+                },
+            }
+        ],
+        "created": 1739462947,
+        "model": "gpt-4o-mini-2024-07-18",
+        "object": "chat.completion",
+        "service_tier": "default",
+        "system_fingerprint": "fp_bd83329f63",
+        "usage": {
+            "completion_tokens": 1,
+            "prompt_tokens": 121,
+            "total_tokens": 122,
+            "completion_tokens_details": {
+                "accepted_prediction_tokens": 0,
+                "audio_tokens": 0,
+                "reasoning_tokens": 0,
+                "rejected_prediction_tokens": 0,
+            },
+            "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+        },
+    }
+
+    return mock_response_data
+
+
+def test_completion_missing_role(openai_api_response):
+    from openai import OpenAI
+
+    from litellm.types.utils import ModelResponse
+
+    client = OpenAI(api_key="test_api_key")
+
+    mock_raw_response = MagicMock()
+    mock_raw_response.headers = {
+        "x-request-id": "123",
+        "openai-organization": "org-123",
+        "x-ratelimit-limit-requests": "100",
+        "x-ratelimit-remaining-requests": "99",
+    }
+    mock_raw_response.parse.return_value = ModelResponse(**openai_api_response)
+
+    print(f"openai_api_response: {openai_api_response}")
+
+    with patch.object(
+        client.chat.completions.with_raw_response, "create", mock_raw_response
+    ) as mock_create:
+        litellm.completion(
+            model="gpt-4o-mini",
+            messages=[
+                {"role": "user", "content": "Hey"},
+                {
+                    "content": "",
+                    "tool_calls": [
+                        {
+                            "id": "call_m0vFJjQmTH1McvaHBPR2YFwY",
+                            "function": {
+                                "arguments": '{"input": "dksjsdkjdhskdjshdskhjkhlk"}',
+                                "name": "tool_name",
+                            },
+                            "type": "function",
+                            "index": 0,
+                        },
+                        {
+                            "id": "call_Vw6RaqV2n5aaANXEdp5pYxo2",
+                            "function": {
+                                "arguments": '{"input": "jkljlkjlkjlkjlk"}',
+                                "name": "tool_name",
+                            },
+                            "type": "function",
+                            "index": 1,
+                        },
+                        {
+                            "id": "call_hBIKwldUEGlNh6NlSXil62K4",
+                            "function": {
+                                "arguments": '{"input": "jkjlkjlkjlkj;lj"}',
+                                "name": "tool_name",
+                            },
+                            "type": "function",
+                            "index": 2,
+                        },
+                    ],
+                },
+            ],
+            client=client,
+        )
+
+    mock_create.assert_called_once()
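
Reviewer note: a minimal sketch of the behavior this patch introduces, based on the
validate_and_fix_openai_messages helper added to litellm/utils.py above. The message
payload below is illustrative, not taken from the test suite:

    # A message dict with no "role" key (OpenAI accepts this for assistant
    # tool-call turns) gets "assistant" filled in, in place, before the
    # existing validate_chat_completion_messages check runs.
    from litellm.utils import validate_and_fix_openai_messages

    messages = [
        {"role": "user", "content": "Hey"},
        {"content": "", "tool_calls": []},  # assistant turn missing "role"
    ]

    validate_and_fix_openai_messages(messages=messages)
    assert messages[1]["role"] == "assistant"  # role was defaulted in place

Because the helper mutates the message dicts before delegating to the existing
validator, callers that previously hit a missing-role validation error (issue
8661) now get the same behavior the OpenAI API itself allows.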