fix(fixes-for-text-completion-streaming): fixes for text completion streaming

Krrish Dholakia 2024-01-08 13:39:54 +05:30
parent 39fb3f2a74
commit ff12e023ae
2 changed files with 38 additions and 18 deletions


@@ -469,6 +469,7 @@ def completion(
"caching_groups",
"ttl",
"cache",
"parent_call"
]
default_params = openai_params + litellm_params
non_default_params = {
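Adding "parent_call" to litellm_params means it is treated as a litellm-internal setting and stripped out before the provider-specific params are collected. A minimal sketch of that filtering follows; the param lists and the split_params helper are illustrative stand-ins, not litellm's actual code.

# Hypothetical sketch of the default-param filtering in completion();
# the real lists are much longer and split_params is not a litellm helper.
openai_params = ["model", "messages", "temperature", "stream"]
litellm_params = ["caching_groups", "ttl", "cache", "parent_call"]

def split_params(kwargs: dict) -> dict:
    default_params = openai_params + litellm_params
    # anything not recognized as a default param is forwarded as provider-specific
    return {k: v for k, v in kwargs.items() if k not in default_params}

print(split_params({"cache": True, "parent_call": "text_completion", "top_k": 40}))
# {'top_k': 40} -- parent_call no longer leaks into the provider params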
@@ -2619,7 +2620,7 @@ def text_completion(
# only use engine when model not passed
model = kwargs["engine"]
kwargs.pop("engine")
kwargs["parent_call"] = kwargs.get("parent_call", "text_completion")
text_completion_response = TextCompletionResponse()
optional_params: Dict[str, Any] = {}
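The new line tags calls routed through text_completion so downstream code can tell which entrypoint originated them, without overriding a value the caller already set. A small standalone sketch of that kwargs defaulting; the function name is illustrative, not a litellm API.

# Illustrative stand-in for the kwargs handling shown above;
# text_completion_kwargs is not a real litellm function.
def text_completion_kwargs(**kwargs) -> dict:
    # default parent_call to "text_completion" unless the caller provided one
    kwargs["parent_call"] = kwargs.get("parent_call", "text_completion")
    return kwargs

print(text_completion_kwargs(prompt="hi"))
# {'prompt': 'hi', 'parent_call': 'text_completion'}
print(text_completion_kwargs(prompt="hi", parent_call="proxy"))
# {'prompt': 'hi', 'parent_call': 'proxy'}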
@@ -2726,6 +2727,7 @@ def text_completion(
if kwargs.get("acompletion", False) == True:
return response
if stream == True or kwargs.get("stream", False) == True:
print(f"original model response: {response}")
response = TextCompletionStreamWrapper(completion_stream=response, model=model)
return response
transformed_logprobs = None
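With stream=True, the raw completion stream is wrapped in TextCompletionStreamWrapper and handed back to the caller. A hedged usage sketch, assuming each streamed chunk mirrors the TextCompletionResponse shape with choices[0]["text"] (model name and chunk shape are assumptions):

# Usage sketch of consuming a streamed text completion
import litellm

stream = litellm.text_completion(
    model="gpt-3.5-turbo-instruct",
    prompt="Write a haiku about streaming.",
    stream=True,
)
for chunk in stream:  # TextCompletionStreamWrapper iterates over chunks
    print(chunk["choices"][0]["text"], end="", flush=True)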
@@ -3162,22 +3164,23 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
else:
completion_output = ""
# # Update usage information if needed
try:
response["usage"]["prompt_tokens"] = token_counter(
model=model, messages=messages
)
except: # don't allow this failing to block a complete streaming response from being returned
print_verbose(f"token_counter failed, assuming prompt tokens is 0")
response["usage"]["prompt_tokens"] = 0
response["usage"]["completion_tokens"] = token_counter(
print(f"INSIDE TEXT COMPLETION STREAM CHUNK BUILDER")
_usage = litellm.Usage()
print(f"messages: {messages}")
_usage.prompt_tokens = token_counter(
model=model, messages=messages, count_response_tokens=True
)
print(f"received prompt tokens: {_usage.prompt_tokens}")
_usage.completion_tokens = token_counter(
model=model,
text=combined_content,
count_response_tokens=True, # count_response_tokens is a flag telling the token counter this is a response, so it skips the extra tokens added for input messages
)
response["usage"]["total_tokens"] = (
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
_usage.total_tokens = (
_usage.prompt_tokens + _usage.completion_tokens
)
return response
response["usage"] = _usage
return litellm.TextCompletionResponse(**response)
def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
id = chunks[0]["id"]
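The rewritten stream_chunk_builder_text_completion counts prompt and completion tokens into a single usage object and attaches it to the rebuilt response; total_tokens is simply the sum of the two. A minimal sketch of that accounting with a stand-in Usage container (litellm.Usage is assumed to expose the same three fields):

# Stand-in Usage container for illustration only
from dataclasses import dataclass

@dataclass
class Usage:
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

def build_usage(prompt_tokens: int, completion_tokens: int) -> Usage:
    usage = Usage()
    usage.prompt_tokens = prompt_tokens
    usage.completion_tokens = completion_tokens
    # total is just the sum of prompt and completion tokens
    usage.total_tokens = usage.prompt_tokens + usage.completion_tokens
    return usage

print(build_usage(12, 30))
# Usage(prompt_tokens=12, completion_tokens=30, total_tokens=42)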