diff --git a/litellm/main.py b/litellm/main.py
index 993927398e..a0e2cc19cb 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -59,6 +59,7 @@ encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
     get_secret,
     CustomStreamWrapper,
+    TextCompletionStreamWrapper,
     ModelResponse,
     TextCompletionResponse,
     TextChoices,
@@ -2031,6 +2032,9 @@ def text_completion(
         **kwargs,
         **optional_params,
     )
+    if stream == True or kwargs.get("stream", False) == True:
+        response = TextCompletionStreamWrapper(completion_stream=response, model=model)
+        return response
 
     transformed_logprobs = None
     # only supported for TGI models
diff --git a/litellm/utils.py b/litellm/utils.py
index 639a6d73ad..6a01138252 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4156,6 +4156,43 @@ class CustomStreamWrapper:
         except StopIteration:
             raise StopAsyncIteration
 
+class TextCompletionStreamWrapper:
+    def __init__(self, completion_stream, model):
+        self.completion_stream = completion_stream
+        self.model = model
+
+    def __iter__(self):
+        return self
+
+    def __aiter__(self):
+        return self
+
+    def __next__(self):
+        # Adapt one chat-style streaming chunk (delta/content) into a
+        # text_completion-style response object.
+        response = TextCompletionResponse()
+        try:
+            while True:  # loop until a non-empty string is found
+                chunk = next(self.completion_stream)
+                response["id"] = chunk.get("id", None)
+                response["object"] = "text_completion"
+                response["created"] = chunk.get("created", None)
+                response["model"] = chunk.get("model", None)
+                text_choices = TextChoices()
+                text_choices["text"] = chunk["choices"][0]["delta"]["content"]
+                text_choices["index"] = chunk["choices"][0]["index"]
+                text_choices["finish_reason"] = chunk["choices"][0]["finish_reason"]
+                response["choices"] = [text_choices]
+                return response
+        except StopIteration:
+            raise StopIteration
+        except Exception as e:
+            raise e
+    async def __anext__(self):
+        try:
+            return next(self)
+        except StopIteration:
+            raise StopAsyncIteration
 
 def mock_completion_streaming_obj(model_response, mock_response, model):
     for i in range(0, len(mock_response), 3):