fix text completion response from codestral

Ishaan Jaff 2024-06-17 15:01:26 -07:00
parent 2057c68217
commit c1981f5534
3 changed files with 49 additions and 6 deletions
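In short: Codestral's FIM endpoint returns chat-style choices (a message object carrying the content), but the handler previously built its result with litellm.TextCompletionResponse(**completion_response), so choices[0].text was left unpopulated. The diff below maps each chat-style choice into litellm.utils.TextChoices before constructing the response. A minimal sketch of the kind of call this fixes (model name, prompt, and suffix are illustrative, and a configured Codestral API key is assumed):

import asyncio

import litellm


async def main():
    # Illustrative FIM-style call; requires a configured Codestral API key.
    response = await litellm.atext_completion(
        model="text-completion-codestral/codestral-2405",
        prompt="def is_odd(n):\n    return n % 2 == 1\n\ndef test_is_odd():",
        suffix="    return True",
        max_tokens=10,
    )
    # Before this commit, .text on the codestral choices came back empty.
    print(response.choices[0].text)


asyncio.run(main())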


@@ -257,7 +257,7 @@ class CodestralTextCompletion(BaseLLM):
             original_response=response.text,
             additional_args={"complete_input_dict": data},
         )
-        print_verbose(f"raw model_response: {response.text}")
+        print_verbose(f"codestral api: raw model_response: {response.text}")
         ## RESPONSE OBJECT
         if response.status_code != 200:
             raise TextCompletionCodestralError(
@@ -269,7 +269,44 @@ class CodestralTextCompletion(BaseLLM):
         except:
             raise TextCompletionCodestralError(message=response.text, status_code=422)
-        _response = litellm.TextCompletionResponse(**completion_response)
+        _original_choices = completion_response.get("choices", [])
+        _choices: List[litellm.utils.TextChoices] = []
+        for choice in _original_choices:
+            # This is what 1 choice looks like from codestral API
+            # {
+            #     "index": 0,
+            #     "message": {
+            #         "role": "assistant",
+            #         "content": "\n assert is_odd(1)\n assert",
+            #         "tool_calls": null
+            #     },
+            #     "finish_reason": "length",
+            #     "logprobs": null
+            # }
+            _finish_reason = None
+            _index = 0
+            _text = None
+            _logprobs = None
+
+            _choice_message = choice.get("message", {})
+            _choice = litellm.utils.TextChoices(
+                finish_reason=choice.get("finish_reason"),
+                index=choice.get("index"),
+                text=_choice_message.get("content"),
+                logprobs=choice.get("logprobs"),
+            )
+
+            _choices.append(_choice)
+
+        _response = litellm.TextCompletionResponse(
+            id=completion_response.get("id"),
+            choices=_choices,
+            created=completion_response.get("created"),
+            model=completion_response.get("model"),
+            usage=completion_response.get("usage"),
+            stream=False,
+            object=completion_response.get("object"),
+        )
         return _response

     def completion(

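The mapping above can be exercised in isolation. A minimal sketch, feeding the sample choice from the comment in the diff through the same constructors (the id, created, and usage values are made up for illustration):

import litellm

completion_response = {
    "id": "cmpl-123",  # illustrative
    "object": "text_completion",
    "created": 1718662886,  # illustrative
    "model": "codestral-2405",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "\n assert is_odd(1)\n assert",
                "tool_calls": None,
            },
            "finish_reason": "length",
            "logprobs": None,
        }
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 8, "total_tokens": 13},
}

# Same mapping as the diff: chat-style message.content becomes choice.text.
_choices = [
    litellm.utils.TextChoices(
        finish_reason=choice.get("finish_reason"),
        index=choice.get("index"),
        text=choice.get("message", {}).get("content"),
        logprobs=choice.get("logprobs"),
    )
    for choice in completion_response.get("choices", [])
]

_response = litellm.TextCompletionResponse(
    id=completion_response.get("id"),
    choices=_choices,
    created=completion_response.get("created"),
    model=completion_response.get("model"),
    usage=completion_response.get("usage"),
    stream=False,
    object=completion_response.get("object"),
)
print(_response.choices[0].text)  # "\n assert is_odd(1)\n assert"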

@@ -355,9 +355,10 @@ async def acompletion(
         else:
             response = init_response  # type: ignore

-        if custom_llm_provider == "text-completion-openai" and isinstance(
-            response, TextCompletionResponse
-        ):
+        if (
+            custom_llm_provider == "text-completion-openai"
+            or custom_llm_provider == "text-completion-codestral"
+        ) and isinstance(response, TextCompletionResponse):
             response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
                 response_object=response,
                 model_response_object=litellm.ModelResponse(),
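For context on the branch above: acompletion serves the chat endpoint, so a TextCompletionResponse coming back from a text-completion provider has to be reshaped into a chat-style ModelResponse; codestral now takes the same conversion path as text-completion-openai. A minimal sketch of that conversion (field values illustrative):

import litellm

# A minimal text-completion response to convert; values are illustrative.
text_response = litellm.TextCompletionResponse(
    choices=[
        litellm.utils.TextChoices(
            text="\n assert is_odd(1)", index=0, finish_reason="length"
        )
    ],
)

# The same call acompletion() makes in the diff above.
chat_response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
    response_object=text_response,
    model_response_object=litellm.ModelResponse(),
)
print(chat_response.choices[0].message.content)  # "\n assert is_odd(1)"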
@@ -3458,7 +3459,9 @@ def embedding(
 ###### Text Completion ################

 @client
-async def atext_completion(*args, **kwargs):
+async def atext_completion(
+    *args, **kwargs
+) -> Union[TextCompletionResponse, TextCompletionStreamWrapper]:
     """
     Implemented to handle async streaming for the text completion endpoint
     """


@@ -4100,6 +4100,9 @@ async def test_completion_codestral_fim_api():
         # Add any assertions here to check the response
         print(response)

+        assert response.choices[0].text is not None
+        assert len(response.choices[0].text) > 0
+
         # cost = litellm.completion_cost(completion_response=response)
         # print("cost to make mistral completion=", cost)
         # assert cost > 0.0
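To verify, the updated test can be run on its own (a Codestral API key is assumed in the environment; the test file's path is not shown in this diff):

pytest -k test_completion_codestral_fim_api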