(fix) asyc callback + stream-stop dbl cnt chunk

This commit is contained in:
ishaan-jaff 2023-12-08 17:22:13 -08:00
parent e237361891
commit 88c1d6649f

View file

@ -5304,7 +5304,7 @@ class CustomStreamWrapper:
return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
return ""
def chunk_creator(self, chunk):
def chunk_creator(self, chunk, in_async_func=False):
model_response = ModelResponse(stream=True, model=self.model)
model_response.choices[0].finish_reason = None
response_obj = {}
@ -5481,7 +5481,10 @@ class CustomStreamWrapper:
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
if in_async_func != True:
# only do logging if we're not being called by _anext_
# _anext_ does its own logging, we check to avoid double counting chunks
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
print_verbose(f"model_response: {model_response}")
return model_response
else:
@ -5489,7 +5492,8 @@ class CustomStreamWrapper:
elif model_response.choices[0].finish_reason:
model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
if in_async_func != True:
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
return model_response
elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints
# enter this branch when no content has been passed in response
@ -5511,7 +5515,8 @@ class CustomStreamWrapper:
model_response.choices[0].delta["role"] = "assistant"
self.sent_first_chunk = True
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() # log response
if in_async_func != True:
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() # log response
return model_response
else:
return
@ -5554,7 +5559,10 @@ class CustomStreamWrapper:
async for chunk in self.completion_stream:
if chunk == "None" or chunk is None:
raise Exception
processed_chunk = self.chunk_creator(chunk=chunk)
# chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks.
# __anext__ also calls async_success_handler, which does logging
processed_chunk = self.chunk_creator(chunk=chunk, in_async_func=True)
if processed_chunk is None:
continue
## LOGGING