fix(fixes-for-text-completion-streaming): fixes for text completion streaming

Krrish Dholakia 2024-01-08 13:39:54 +05:30
parent 39fb3f2a74
commit ff12e023ae
2 changed files with 38 additions and 18 deletions


@@ -469,6 +469,7 @@ def completion(
         "caching_groups",
         "ttl",
         "cache",
+        "parent_call"
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -2619,7 +2620,7 @@ def text_completion(
         # only use engine when model not passed
         model = kwargs["engine"]
         kwargs.pop("engine")
+    kwargs["parent_call"] = kwargs.get("parent_call", "text_completion")
     text_completion_response = TextCompletionResponse()
     optional_params: Dict[str, Any] = {}
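The `parent_call` kwarg is a marker that lets the shared completion path (and the `client` decorator further down) tell a direct call apart from one made on behalf of `text_completion`. A minimal, self-contained sketch of that pattern, using hypothetical function names rather than litellm's real entrypoints:

# Illustrative sketch only: a marker kwarg lets shared plumbing tell a direct
# call apart from one made on behalf of another public entrypoint.
def completion_like(**kwargs):
    parent_call = kwargs.pop("parent_call", None)
    if parent_call is None:
        # direct call: safe to do call-type-specific post-processing here
        return {"handled_by": "completion"}
    # nested call: leave post-processing to the parent entrypoint
    return {"handled_by": parent_call}

def text_completion_like(prompt, **kwargs):
    # mark the nested call so the shared path skips chat-style handling
    kwargs["parent_call"] = kwargs.get("parent_call", "text_completion")
    return completion_like(prompt=prompt, **kwargs)

print(text_completion_like("hi"))  # {'handled_by': 'text_completion'}
print(completion_like())           # {'handled_by': 'completion'}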
@@ -2726,6 +2727,7 @@ def text_completion(
     if kwargs.get("acompletion", False) == True:
         return response
     if stream == True or kwargs.get("stream", False) == True:
+        print(f"original model response: {response}")
         response = TextCompletionStreamWrapper(completion_stream=response, model=model)
         return response
     transformed_logprobs = None
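With `stream=True`, `text_completion` hands back a `TextCompletionStreamWrapper` around the underlying completion stream. A hedged usage sketch of what the caller then sees; the model name and the exact chunk fields are assumptions for illustration, not guaranteed by this diff:

# Sketch, assuming litellm is installed and an API key is configured;
# model name and chunk field access are illustrative.
import litellm

response = litellm.text_completion(
    model="gpt-3.5-turbo-instruct",
    prompt="Say hello",
    stream=True,
)
for chunk in response:  # each chunk is a text_completion-shaped object
    print(chunk["choices"][0].get("text", ""), end="", flush=True)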
@@ -3162,22 +3164,23 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
     else:
         completion_output = ""
     # # Update usage information if needed
-    try:
-        response["usage"]["prompt_tokens"] = token_counter(
-            model=model, messages=messages
-        )
-    except:  # don't allow this failing to block a complete streaming response from being returned
-        print_verbose(f"token_counter failed, assuming prompt tokens is 0")
-        response["usage"]["prompt_tokens"] = 0
-    response["usage"]["completion_tokens"] = token_counter(
+    print(f"INSIDE TEXT COMPLETION STREAM CHUNK BUILDER")
+    _usage = litellm.Usage
+    print(f"messages: {messages}")
+    _usage.prompt_tokens = token_counter(
+        model=model, messages=messages, count_response_tokens=True
+    )
+    print(f"received prompt tokens: {_usage.prompt_tokens}")
+    _usage.completion_tokens = token_counter(
         model=model,
         text=combined_content,
         count_response_tokens=True,  # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
     )
-    response["usage"]["total_tokens"] = (
-        response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
-    )
-    return response
+    _usage.total_tokens = (
+        _usage.prompt_tokens + _usage.completion_tokens
+    )
+    response["usage"] = _usage
+    return litellm.TextCompletionResponse(**response)


 def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
     id = chunks[0]["id"]
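Read as standalone code, the rewritten usage block counts prompt tokens from the prompt-as-messages, counts completion tokens from the concatenated streamed text, and attaches the result to the rebuilt response. A hedged sketch of that computation; instantiating `Usage()` (rather than referencing the class object, as the diff does) and the model name are assumptions made here for clarity:

# Sketch of the usage computation, assuming litellm.token_counter and
# litellm.Usage behave as they are used in the hunk above.
import litellm
from litellm import token_counter

def build_usage(model, messages, combined_content):
    usage = litellm.Usage()  # assumption: instantiate instead of using the class object
    usage.prompt_tokens = token_counter(
        model=model, messages=messages, count_response_tokens=True
    )
    usage.completion_tokens = token_counter(
        model=model,
        text=combined_content,
        count_response_tokens=True,  # response text: skip per-message overhead
    )
    usage.total_tokens = usage.prompt_tokens + usage.completion_tokens
    return usage

usage = build_usage(
    model="gpt-3.5-turbo-instruct",
    messages=[{"role": "user", "content": "Say hello"}],
    combined_content="Hello there!",
)
print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)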


@@ -2038,12 +2038,19 @@ def client(original_function):
             if (
                 "complete_response" in kwargs
                 and kwargs["complete_response"] == True
+                and kwargs.get("parent_call", None) is None
             ):
                 chunks = []
                 for idx, chunk in enumerate(result):
                     chunks.append(chunk)
-                return litellm.stream_chunk_builder(
-                    chunks, messages=kwargs.get("messages", None)
-                )
+                call_type = original_function.__name__
+                if call_type == CallTypes.completion.value:
+                    return litellm.stream_chunk_builder(
+                        chunks, messages=kwargs.get("messages")
+                    )
+                elif call_type == CallTypes.text_completion.value:
+                    return litellm.stream_chunk_builder(
+                        chunks, messages=[{"role": "user", "content": kwargs.get("prompt")}]
+                    )
             else:
                 return result
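The decorator now dispatches on the wrapped function's `__name__` when rebuilding a complete response from collected stream chunks, wrapping a raw `prompt` as a single user message for the text-completion case. A minimal stand-alone sketch of that dispatch; `CallTypes` below is a stand-in enum, not litellm's own definition:

# Hypothetical stand-in for the dispatch on original_function.__name__.
from enum import Enum

class CallTypes(Enum):  # stand-in; litellm defines its own CallTypes
    completion = "completion"
    text_completion = "text_completion"

def assemble_complete_response(original_function, chunks, kwargs):
    call_type = original_function.__name__
    if call_type == CallTypes.completion.value:
        # chat call: the original messages are already chat-shaped
        return {"chunks": chunks, "messages": kwargs.get("messages")}
    elif call_type == CallTypes.text_completion.value:
        # text call: wrap the raw prompt as a single user message
        return {
            "chunks": chunks,
            "messages": [{"role": "user", "content": kwargs.get("prompt")}],
        }
    return None

def text_completion(prompt=None, **kwargs): ...

print(assemble_complete_response(text_completion, chunks=[], kwargs={"prompt": "hi"}))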
@@ -2695,9 +2702,10 @@ def token_counter(
         raise ValueError("text and messages cannot both be None")
     elif isinstance(text, List):
         text = "".join(t for t in text if isinstance(t, str))
+    print_verbose(f"text: {text}")
     if model is not None:
         tokenizer_json = _select_tokenizer(model=model)
+        print(f"tokenizer_json['type']: {tokenizer_json['type']}")
         if tokenizer_json["type"] == "huggingface_tokenizer":
             print_verbose(
                 f"Token Counter - using hugging face token counter, for model={model}"
@@ -2731,6 +2739,7 @@ def token_counter(
             num_tokens = len(enc)
     else:
         num_tokens = len(encoding.encode(text))  # type: ignore
+    print_verbose(f"final num tokens returned: {num_tokens}")
     return num_tokens
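The `elif isinstance(text, List)` branch above filters out non-string items and joins the rest before encoding, so the counter can be handed streamed text fragments directly. A small hedged example, assuming `token_counter` keeps the keyword arguments shown in this diff:

# Sketch: counting streamed response fragments with token_counter.
from litellm import token_counter

streamed_pieces = ["Hello", " ", "there", "!", None]  # non-strings are filtered out
num_tokens = token_counter(
    model="gpt-3.5-turbo-instruct",
    text=streamed_pieces,  # list input is joined before encoding, per the branch above
    count_response_tokens=True,
)
print(num_tokens)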
@@ -7760,6 +7769,8 @@ class TextCompletionStreamWrapper:
     def convert_to_text_completion_object(self, chunk: ModelResponse):
         try:
+            if not isinstance(chunk, ModelResponse):
+                return
             response = TextCompletionResponse()
             response["id"] = chunk.get("id", None)
             response["object"] = "text_completion"
@@ -7784,12 +7795,18 @@ class TextCompletionStreamWrapper:
         # model_response = ModelResponse(stream=True, model=self.model)
         response = TextCompletionResponse()
         try:
-            for chunk in self.completion_stream:
+            while True:
+                if isinstance(self.completion_stream, str) or isinstance(
+                    self.completion_stream, bytes
+                ) or isinstance(self.completion_stream, ModelResponse):
+                    chunk = self.completion_stream
+                else:
+                    chunk = next(self.completion_stream)
                 if chunk == "None" or chunk is None:
                     raise Exception
                 processed_chunk = self.convert_to_text_completion_object(chunk=chunk)
                 return processed_chunk
+            raise StopIteration
         except StopIteration:
             raise StopIteration
         except Exception as e:
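The reworked `__next__` must cope both with a `completion_stream` that is a real iterator and with one that is a single, already-materialized value (str/bytes/ModelResponse), handing back one converted chunk per call and ending in `StopIteration`. A simplified, self-contained sketch of that pattern; the class and the chunk shape are stand-ins, not litellm's wrapper:

# Simplified stand-in for the wrapper's __next__ logic.
class TextStreamWrapperSketch:
    def __init__(self, completion_stream):
        self.completion_stream = completion_stream
        self._single_value_emitted = False

    def __iter__(self):
        return self

    def __next__(self):
        if isinstance(self.completion_stream, (str, bytes)):
            # single pre-materialized value: emit it once, then stop
            if self._single_value_emitted:
                raise StopIteration
            self._single_value_emitted = True
            chunk = self.completion_stream
        else:
            chunk = next(self.completion_stream)  # raises StopIteration when exhausted
        if chunk is None or chunk == "None":
            raise StopIteration
        return {"object": "text_completion", "text": str(chunk)}

print(list(TextStreamWrapperSketch(iter(["Hel", "lo"]))))  # two chunks
print(list(TextStreamWrapperSketch("Hello")))              # one chunk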