mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
(feat) add streaming for text_completion
This commit is contained in:
parent
a404b0fc3b
commit
2a751c277f
2 changed files with 41 additions and 0 deletions
|
@ -59,6 +59,7 @@ encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
get_secret,
|
get_secret,
|
||||||
CustomStreamWrapper,
|
CustomStreamWrapper,
|
||||||
|
TextCompletionStreamWrapper,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
TextChoices,
|
TextChoices,
|
||||||
|
@ -2031,6 +2032,9 @@ def text_completion(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
**optional_params,
|
**optional_params,
|
||||||
)
|
)
|
||||||
|
if stream == True or kwargs.get("stream", False) == True:
|
||||||
|
response = TextCompletionStreamWrapper(completion_stream=response, model=model)
|
||||||
|
return response
|
||||||
|
|
||||||
transformed_logprobs = None
|
transformed_logprobs = None
|
||||||
# only supported for TGI models
|
# only supported for TGI models
|
||||||
|
|
|
@ -4156,6 +4156,43 @@ class CustomStreamWrapper:
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
class TextCompletionStreamWrapper:
    """Adapts a streaming chat-completion iterator into the OpenAI
    text-completion streaming shape.

    Each chunk pulled from ``completion_stream`` (chat format, with
    ``choices[0].delta.content``) is repackaged as a
    ``TextCompletionResponse`` whose single ``TextChoices`` entry carries
    the delta text. Supports both sync (``for``) and async
    (``async for``) iteration.
    """

    def __init__(self, completion_stream, model):
        # completion_stream: iterator/async-compatible stream of chat chunks
        # (e.g. a CustomStreamWrapper). model: model name, kept for reference.
        self.completion_stream = completion_stream
        self.model = model

    def __iter__(self):
        return self

    def __aiter__(self):
        return self

    def __next__(self):
        response = TextCompletionResponse()
        try:
            chunk = next(self.completion_stream)
            response["id"] = chunk.get("id", None)
            response["object"] = "text_completion"
            # BUG FIX: "created"/"model" must come from the incoming chunk;
            # the original read them back from the freshly-built (empty)
            # response, so they were always None.
            response["created"] = chunk.get("created", None)
            response["model"] = chunk.get("model", None)
            text_choices = TextChoices()
            text_choices["text"] = chunk["choices"][0]["delta"]["content"]
            # BUG FIX: index/finish_reason likewise belong to the chunk's
            # choice, not the default choice of the new response object.
            text_choices["index"] = chunk["choices"][0]["index"]
            text_choices["finish_reason"] = chunk["choices"][0]["finish_reason"]
            response["choices"] = [text_choices]
            return response
        except StopIteration:
            # End of the underlying stream: propagate to terminate iteration.
            raise StopIteration
        except Exception as e:
            # BUG FIX: the original printed and fell through, making __next__
            # return None (an invalid chunk) and silently hiding the error.
            # Log best-effort, then re-raise so callers see the failure.
            print(f"got exception {e}")
            raise

    async def __anext__(self):
        try:
            # Delegate to the sync path; translate the sync stop signal into
            # the async one required by the async-iterator protocol.
            return next(self)
        except StopIteration:
            raise StopAsyncIteration
|
|
||||||
def mock_completion_streaming_obj(model_response, mock_response, model):
|
def mock_completion_streaming_obj(model_response, mock_response, model):
|
||||||
for i in range(0, len(mock_response), 3):
|
for i in range(0, len(mock_response), 3):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue