diff --git a/litellm/main.py b/litellm/main.py
index f045ffb816..076e4c7f40 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -60,7 +60,9 @@ from litellm.utils import (
     ModelResponse,
     EmbeddingResponse,
     read_config_args,
-    RateLimitManager
+    RateLimitManager,
+    Choices,
+    Message
 )
 
 ####### ENVIRONMENT VARIABLES ###################
@@ -509,8 +511,12 @@ def completion(
                 },
             )
             ## RESPONSE OBJECT
-            completion_response = response["choices"][0]["text"]
-            model_response["choices"][0]["message"]["content"] = completion_response
+            choices_list = []
+            for idx, item in enumerate(response["choices"]):
+                message_obj = Message(content=item["text"])
+                choice_obj = Choices(finish_reason=item["finish_reason"], index=idx, message=message_obj)
+                choices_list.append(choice_obj)
+            model_response["choices"] = choices_list
             model_response["created"] = response.get("created", time.time())
             model_response["model"] = model
             model_response["usage"] = response.get("usage", 0)
diff --git a/litellm/tests/test_provider_specific_config.py b/litellm/tests/test_provider_specific_config.py
index 9cf7ecb045..0e1ac9283f 100644
--- a/litellm/tests/test_provider_specific_config.py
+++ b/litellm/tests/test_provider_specific_config.py
@@ -498,6 +498,11 @@ def openai_text_completion_test():
         print(f"response_2_text: {response_2_text}")
 
         assert len(response_2_text) < len(response_1_text)
+
+        response_3 = litellm.completion(model="text-davinci-003",
+                                        messages=[{"content": "Hello, how are you?", "role": "user"}],
+                                        n=2)
+        assert len(response_3.choices) > 1
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
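
The change above fans every `choices` entry from the raw OpenAI text-completion response out into a chat-style choice, rather than keeping only `choices[0]`. Below is a minimal, self-contained sketch of that mapping, using plain dicts in place of litellm's internal `Choices`/`Message` helpers (whose exact constructors live in `litellm.utils`); `raw_response` and its contents are made-up illustrative data, not output from a real API call.

```python
# Sketch of the response-object mapping above, with plain dicts standing in
# for litellm's Choices/Message helpers. `raw_response` mimics what the
# OpenAI text-completion endpoint returns when n > 1.
import time

raw_response = {
    "choices": [
        {"text": "I'm doing well, thanks!", "finish_reason": "stop"},
        {"text": "Great, and you?", "finish_reason": "stop"},
    ],
    "created": 1693000000,
    "usage": {"prompt_tokens": 6, "completion_tokens": 12, "total_tokens": 18},
}

model_response = {}
choices_list = []
for idx, item in enumerate(raw_response["choices"]):
    # Wrap each text completion in a chat-style message so callers can
    # always read response["choices"][i]["message"]["content"].
    message_obj = {"role": "assistant", "content": item["text"]}
    choice_obj = {
        "finish_reason": item["finish_reason"],
        "index": idx,  # 0-based, matching the OpenAI chat schema
        "message": message_obj,
    }
    choices_list.append(choice_obj)

model_response["choices"] = choices_list
model_response["created"] = raw_response.get("created", time.time())
model_response["usage"] = raw_response.get("usage", 0)

assert len(model_response["choices"]) == 2
assert model_response["choices"][1]["message"]["content"] == "Great, and you?"
```

With the mapping in place, a call such as `litellm.completion(model="text-davinci-003", messages=[...], n=2)` should surface every generation in `response.choices` instead of silently dropping all but the first, which is what the new test asserts.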