(feat) add streaming for text_completion

This commit is contained in:
ishaan-jaff 2023-11-08 11:58:07 -08:00
parent a404b0fc3b
commit 2a751c277f
2 changed files with 41 additions and 0 deletions

View file

@ -59,6 +59,7 @@ encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
get_secret,
CustomStreamWrapper,
TextCompletionStreamWrapper,
ModelResponse,
TextCompletionResponse,
TextChoices,
@ -2031,6 +2032,9 @@ def text_completion(
**kwargs,
**optional_params,
)
if stream == True or kwargs.get("stream", False) == True:
response = TextCompletionStreamWrapper(completion_stream=response, model=model)
return response
transformed_logprobs = None
# only supported for TGI models

View file

@ -4156,6 +4156,43 @@ class CustomStreamWrapper:
except StopIteration:
raise StopAsyncIteration
class TextCompletionStreamWrapper:
def __init__(self, completion_stream, model):
self.completion_stream = completion_stream
self.model = model
def __iter__(self):
return self
def __aiter__(self):
return self
def __next__(self):
# model_response = ModelResponse(stream=True, model=self.model)
response = TextCompletionResponse()
try:
while True: # loop until a non-empty string is found
# return this for all models
chunk = next(self.completion_stream)
response["id"] = chunk.get("id", None)
response["object"] = "text_completion"
response["created"] = response.get("created", None)
response["model"] = response.get("model", None)
text_choices = TextChoices()
text_choices["text"] = chunk["choices"][0]["delta"]["content"]
text_choices["index"] = response["choices"][0]["index"]
text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
response["choices"] = [text_choices]
return response
except StopIteration:
raise StopIteration
except Exception as e:
print(f"got exception {e}")
async def __anext__(self):
try:
return next(self)
except StopIteration:
raise StopAsyncIteration
def mock_completion_streaming_obj(model_response, mock_response, model):
for i in range(0, len(mock_response), 3):