From c1981f5534ec8b10d2bda6a9f51c6e0da68d897b Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 17 Jun 2024 15:01:26 -0700
Subject: [PATCH] fix text completion response from codestral

---
 litellm/llms/text_completion_codestral.py | 41 +++++++++++++++++++++--
 litellm/main.py                           | 11 +++---
 litellm/tests/test_text_completion.py     |  3 ++
 3 files changed, 49 insertions(+), 6 deletions(-)

diff --git a/litellm/llms/text_completion_codestral.py b/litellm/llms/text_completion_codestral.py
index a46b67b126..e732706b4a 100644
--- a/litellm/llms/text_completion_codestral.py
+++ b/litellm/llms/text_completion_codestral.py
@@ -257,7 +257,7 @@ class CodestralTextCompletion(BaseLLM):
             original_response=response.text,
             additional_args={"complete_input_dict": data},
         )
-        print_verbose(f"raw model_response: {response.text}")
+        print_verbose(f"codestral api: raw model_response: {response.text}")
         ## RESPONSE OBJECT
         if response.status_code != 200:
             raise TextCompletionCodestralError(
@@ -269,7 +269,44 @@ class CodestralTextCompletion(BaseLLM):
         except:
             raise TextCompletionCodestralError(message=response.text, status_code=422)
 
-        _response = litellm.TextCompletionResponse(**completion_response)
+        _original_choices = completion_response.get("choices", [])
+        _choices: List[litellm.utils.TextChoices] = []
+        for choice in _original_choices:
+            # This is what 1 choice looks like from codestral API
+            # {
+            #     "index": 0,
+            #     "message": {
+            #         "role": "assistant",
+            #         "content": "\n assert is_odd(1)\n assert",
+            #         "tool_calls": null
+            #     },
+            #     "finish_reason": "length",
+            #     "logprobs": null
+            # }
+            _finish_reason = None
+            _index = 0
+            _text = None
+            _logprobs = None
+
+            _choice_message = choice.get("message", {})
+            _choice = litellm.utils.TextChoices(
+                finish_reason=choice.get("finish_reason"),
+                index=choice.get("index"),
+                text=_choice_message.get("content"),
+                logprobs=choice.get("logprobs"),
+            )
+
+            _choices.append(_choice)
+
+        _response = litellm.TextCompletionResponse(
+            id=completion_response.get("id"),
+            choices=_choices,
+            created=completion_response.get("created"),
+            model=completion_response.get("model"),
+            usage=completion_response.get("usage"),
+            stream=False,
+            object=completion_response.get("object"),
+        )
         return _response
 
     def completion(
diff --git a/litellm/main.py b/litellm/main.py
index 0540d29cde..31809ef7fb 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -355,9 +355,10 @@ async def acompletion(
         else:
             response = init_response  # type: ignore
 
-        if custom_llm_provider == "text-completion-openai" and isinstance(
-            response, TextCompletionResponse
-        ):
+        if (
+            custom_llm_provider == "text-completion-openai"
+            or custom_llm_provider == "text-completion-codestral"
+        ) and isinstance(response, TextCompletionResponse):
             response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
                 response_object=response,
                 model_response_object=litellm.ModelResponse(),
@@ -3458,7 +3459,9 @@ def embedding(
 
 ###### Text Completion ################
 @client
-async def atext_completion(*args, **kwargs):
+async def atext_completion(
+    *args, **kwargs
+) -> Union[TextCompletionResponse, TextCompletionStreamWrapper]:
     """
     Implemented to handle async streaming for the text completion endpoint
     """
diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py
index bfd292a511..61f649a224 100644
--- a/litellm/tests/test_text_completion.py
+++ b/litellm/tests/test_text_completion.py
@@ -4100,6 +4100,9 @@ async def test_completion_codestral_fim_api():
         # Add any assertions here to check the response
         print(response)
 
+        assert response.choices[0].text is not None
+        assert len(response.choices[0].text) > 0
+
         # cost = litellm.completion_cost(completion_response=response)
         # print("cost to make mistral completion=", cost)
         # assert cost > 0.0
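
Below is a minimal, untested sketch of how the patched path could be exercised end to end; the model identifier, prompt, and script structure are illustrative assumptions and are not taken from the patch above.

import asyncio

import litellm


async def main():
    # Hypothetical usage sketch (not part of the patch): calls the async text
    # completion entry point that now converts codestral choices into TextChoices.
    response = await litellm.atext_completion(
        model="text-completion-codestral/codestral-2405",  # assumed model name
        prompt="def is_odd(n):\n    return n % 2 == 1\ndef test_is_odd():",
        max_tokens=10,
    )
    # With the fix, each choice carries the generated text in `.text`.
    print(response.choices[0].text)


if __name__ == "__main__":
    asyncio.run(main())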