(fix) streaming + function / tool calling: pass the original chunk through to the caller

This commit is contained in:
ishaan-jaff 2023-11-18 16:23:05 -08:00
parent c02794d3ff
commit 70fc5afb5d

View file

@ -4598,13 +4598,23 @@ class CustomStreamWrapper:
text = "" text = ""
is_finished = False is_finished = False
finish_reason = None finish_reason = None
original_chunk = None # this is used for function/tool calling
if len(str_line.choices) > 0: if len(str_line.choices) > 0:
if str_line.choices[0].delta.content is not None: if str_line.choices[0].delta.content is not None:
text = str_line.choices[0].delta.content text = str_line.choices[0].delta.content
else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
original_chunk = str_line
if str_line.choices[0].finish_reason: if str_line.choices[0].finish_reason:
is_finished = True is_finished = True
finish_reason = str_line.choices[0].finish_reason finish_reason = str_line.choices[0].finish_reason
return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
"original_chunk": str_line
}
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise e raise e
@ -4856,6 +4866,13 @@ class CustomStreamWrapper:
return model_response return model_response
else: else:
return return
elif response_obj.get("original_chunk", None) is not None: # function / tool calling branch
model_response = response_obj.get("original_chunk", None)
if self.sent_first_chunk == False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() # log response
return model_response
elif model_response.choices[0].finish_reason: elif model_response.choices[0].finish_reason:
model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
# LOGGING # LOGGING