diff --git a/litellm/main.py b/litellm/main.py index 9dd6186312..d208966311 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -9,6 +9,7 @@ import os, openai, sys, json, inspect, uuid, datetime, threading from typing import Any, Literal, Union, BinaryIO +from typing_extensions import overload from functools import partial import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 70d6824e9e..45a53ca56a 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -68,6 +68,51 @@ def test_completion_custom_provider_model_name(): pytest.fail(f"Error occurred: {e}") +def _openai_mock_response(*args, **kwargs) -> litellm.ModelResponse: + _data = { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0125", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": None, + "content": "\n\nHello there, how may I assist you today?", + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}, + } + return litellm.ModelResponse(**_data) + + +def test_null_role_response(): + """ + Test if api returns 'null' role, 'assistant' role is still returned + """ + import openai + + openai_client = openai.OpenAI() + with patch.object( + openai_client.chat.completions, "create", side_effect=_openai_mock_response + ) as mock_response: + response = litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey! how's it going?"}], + client=openai_client, + ) + print(f"response: {response}") + + assert response.id == "chatcmpl-123" + + assert response.choices[0].message.role == "assistant" + + def test_completion_azure_command_r(): try: litellm.set_verbose = True diff --git a/litellm/utils.py b/litellm/utils.py index e2bc8bec18..5d7cf63456 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7709,7 +7709,7 @@ def convert_to_model_response_object( for idx, choice in enumerate(response_object["choices"]): message = Message( content=choice["message"].get("content", None), - role=choice["message"]["role"], + role=choice["message"]["role"] or "assistant", function_call=choice["message"].get("function_call", None), tool_calls=choice["message"].get("tool_calls", None), )