diff --git a/litellm/llms/text_completion_codestral.py b/litellm/llms/text_completion_codestral.py
index 124c840c3..a46b67b12 100644
--- a/litellm/llms/text_completion_codestral.py
+++ b/litellm/llms/text_completion_codestral.py
@@ -17,6 +17,7 @@ from litellm.utils import (
     Choices,
 )
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.types.llms.databricks import GenericStreamingChunk
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
@@ -160,6 +161,39 @@ class MistralTextCompletionConfig:
 
         return optional_params
 
+    def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
+        text = ""
+        is_finished = False
+        finish_reason = None
+        logprobs = None
+
+        chunk_data = chunk_data.replace("data:", "")
+        chunk_data = chunk_data.strip()
+        if len(chunk_data) == 0 or chunk_data == "[DONE]":
+            return {
+                "text": "",
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+        chunk_data_dict = json.loads(chunk_data)
+        original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
+        _choices = chunk_data_dict.get("choices", []) or []
+        _choice = _choices[0]
+        text = _choice.get("delta", {}).get("content", "")
+
+        if _choice.get("finish_reason") is not None:
+            is_finished = True
+            finish_reason = _choice.get("finish_reason")
+            logprobs = _choice.get("logprobs")
+
+        return GenericStreamingChunk(
+            text=text,
+            original_chunk=original_chunk,
+            is_finished=is_finished,
+            finish_reason=finish_reason,
+            logprobs=logprobs,
+        )
+
 
 class CodestralTextCompletion(BaseLLM):
     def __init__(self) -> None:
@@ -452,7 +486,7 @@ class CodestralTextCompletion(BaseLLM):
                 logging_obj=logging_obj,
             ),
             model=model,
-            custom_llm_provider="codestral",
+            custom_llm_provider="text-completion-codestral",
             logging_obj=logging_obj,
         )
         return streamwrapper
diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py
index 1ddd8ea6b..bfd292a51 100644
--- a/litellm/tests/test_text_completion.py
+++ b/litellm/tests/test_text_completion.py
@@ -4105,3 +4105,40 @@ async def test_completion_codestral_fim_api():
         # assert cost > 0.0
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_completion_codestral_fim_api_stream():
+    try:
+        from litellm._logging import verbose_logger
+        import logging
+
+        litellm.set_verbose = False
+
+        # verbose_logger.setLevel(level=logging.DEBUG)
+        response = await litellm.atext_completion(
+            model="text-completion-codestral/codestral-2405",
+            prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
+            suffix="return True",
+            temperature=0,
+            top_p=1,
+            stream=True,
+            seed=10,
+            stop=["return"],
+        )
+
+        full_response = ""
+        # Add any assertions here to check the response
+        async for chunk in response:
+            print(chunk)
+            full_response += chunk.get("choices")[0].get("text") or ""
+
+        print("full_response", full_response)
+
+        assert len(full_response) > 2  # we at least have a few chars in response :)
+
+        # cost = litellm.completion_cost(completion_response=response)
+        # print("cost to make mistral completion=", cost)
+        # assert cost > 0.0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
diff --git a/litellm/utils.py b/litellm/utils.py
index f66077d7a..5da62505a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -8546,6 +8546,25 @@ class CustomStreamWrapper:
                         completion_tokens=response_obj["usage"].completion_tokens,
                         total_tokens=response_obj["usage"].total_tokens,
                     )
+            elif self.custom_llm_provider == "text-completion-codestral":
+                response_obj = litellm.MistralTextCompletionConfig()._chunk_parser(
+                    chunk
+                )
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) == True
+                    and response_obj["usage"] is not None
+                ):
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"].prompt_tokens,
+                        completion_tokens=response_obj["usage"].completion_tokens,
+                        total_tokens=response_obj["usage"].total_tokens,
+                    )
             elif self.custom_llm_provider == "databricks":
                 response_obj = litellm.DatabricksConfig()._chunk_parser(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -9019,6 +9038,7 @@
                 or self.custom_llm_provider == "azure"
                 or self.custom_llm_provider == "custom_openai"
                 or self.custom_llm_provider == "text-completion-openai"
+                or self.custom_llm_provider == "text-completion-codestral"
                 or self.custom_llm_provider == "azure_text"
                 or self.custom_llm_provider == "anthropic"
                 or self.custom_llm_provider == "anthropic_text"
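
A minimal usage sketch of the chunk parser introduced by this diff. `MistralTextCompletionConfig` and `_chunk_parser` come from the change above; the SSE payload is hand-written for illustration and is not a recorded codestral response:

# Illustrative sketch of the new streaming chunk parser (payload is hypothetical).
import litellm

config = litellm.MistralTextCompletionConfig()

# Hand-written SSE line shaped like a codestral FIM streaming chunk.
sse_line = (
    'data: {"id": "cmpl-123", "created": 1718000000, "model": "codestral-2405", '
    '"choices": [{"index": 0, "delta": {"content": " return n % 2 == 1"}, '
    '"finish_reason": null}]}'
)

chunk = config._chunk_parser(sse_line)
assert chunk["text"] == " return n % 2 == 1"
assert chunk["is_finished"] is False

# The stream terminator short-circuits to an empty chunk.
done = config._chunk_parser("data: [DONE]")
assert done["text"] == "" and done["is_finished"] is False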