fix(utils.py): ensure last chunk is always empty delta w/ finish reason

Makes sure we're OpenAI-compatible with our streaming. Adds stricter tests for this as well.
Krrish Dholakia 2024-03-25 16:33:41 -07:00
parent f153889738
commit 9e1e97528d
3 changed files with 221 additions and 285 deletions
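
For context, OpenAI-compatible streaming means the final chunk carries an empty delta together with the finish_reason, while earlier chunks carry the content (or tool calls) with a null finish_reason. A rough sketch of that contract as a litellm consumer would observe it (placeholder model and prompt; this is not the commit's actual test code):

import litellm

# Rough sketch of the behavior this change enforces: the last streamed chunk
# should be an empty delta that carries the finish_reason, matching OpenAI's
# streaming format. Requires OPENAI_API_KEY in the environment; model and
# prompt below are placeholders.
chunks = list(
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
    )
)

last = chunks[-1]
assert last.choices[0].finish_reason is not None   # e.g. "stop"
assert not last.choices[0].delta.content            # empty delta
assert last.choices[0].delta.tool_calls is None
assert last.choices[0].delta.function_call is None

# Earlier chunks should not carry the finish_reason.
assert all(c.choices[0].finish_reason is None for c in chunks[:-1])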


@@ -8458,6 +8458,7 @@ class CustomStreamWrapper:
self.completion_stream = completion_stream
self.sent_first_chunk = False
self.sent_last_chunk = False
self.received_finish_reason: Optional[str] = None
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
self.holding_chunk = ""
self.complete_response = ""
@@ -9131,7 +9132,7 @@ class CustomStreamWrapper:
"finish_reason": finish_reason,
}
def chunk_creator(self, chunk):
def model_response_creator(self):
model_response = ModelResponse(stream=True, model=self.model)
if self.response_id is not None:
model_response.id = self.response_id
@@ -9141,6 +9142,20 @@ class CustomStreamWrapper:
model_response._hidden_params["created_at"] = time.time()
model_response.choices = [StreamingChoices()]
model_response.choices[0].finish_reason = None
return model_response
def is_delta_empty(self, delta: Delta) -> bool:
is_empty = True
if delta.content is not None:
is_empty = False
elif delta.tool_calls is not None:
is_empty = False
elif delta.function_call is not None:
is_empty = False
return is_empty
def chunk_creator(self, chunk):
model_response = self.model_response_creator()
response_obj = {}
try:
# return this for all models
@@ -9149,30 +9164,22 @@ class CustomStreamWrapper:
response_obj = self.handle_anthropic_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
response_obj = self.handle_replicate_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "together_ai":
response_obj = self.handle_together_ai_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
response_obj = self.handle_huggingface_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif (
self.custom_llm_provider and self.custom_llm_provider == "baseten"
): # baseten doesn't provide streaming
@@ -9183,16 +9190,12 @@ class CustomStreamWrapper:
response_obj = self.handle_ai21_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
response_obj = self.handle_maritalk_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
completion_obj["content"] = chunk[0].outputs[0].text
elif (
@@ -9201,152 +9204,116 @@ class CustomStreamWrapper:
response_obj = self.handle_aleph_alpha_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "nlp_cloud":
try:
response_obj = self.handle_nlp_cloud_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
except Exception as e:
if self.sent_last_chunk:
if self.received_finish_reason:
raise e
else:
if self.sent_first_chunk is False:
raise Exception("An unknown error occurred with the stream")
model_response.choices[0].finish_reason = "stop"
self.sent_last_chunk = True
self.received_finish_reason = "stop"
elif self.custom_llm_provider == "gemini":
try:
if hasattr(chunk, "parts") == True:
try:
if len(chunk.parts) > 0:
completion_obj["content"] = chunk.parts[0].text
if hasattr(chunk.parts[0], "finish_reason"):
model_response.choices[0].finish_reason = (
map_finish_reason(chunk.parts[0].finish_reason.name)
)
except:
if chunk.parts[0].finish_reason.name == "SAFETY":
raise Exception(
f"The response was blocked by VertexAI. {str(chunk)}"
)
else:
completion_obj["content"] = str(chunk)
except StopIteration as e:
if self.sent_last_chunk:
raise e
else:
model_response.choices[0].finish_reason = "stop"
self.sent_last_chunk = True
if hasattr(chunk, "parts") == True:
try:
if len(chunk.parts) > 0:
completion_obj["content"] = chunk.parts[0].text
if hasattr(chunk.parts[0], "finish_reason"):
self.received_finish_reason = chunk.parts[
0
].finish_reason.name
except:
if chunk.parts[0].finish_reason.name == "SAFETY":
raise Exception(
f"The response was blocked by VertexAI. {str(chunk)}"
)
else:
completion_obj["content"] = str(chunk)
elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
try:
if hasattr(chunk, "candidates") == True:
if hasattr(chunk, "candidates") == True:
try:
try:
try:
completion_obj["content"] = chunk.text
except Exception as e:
if "Part has no text." in str(e):
## check for function calling
function_call = (
chunk.candidates[0]
.content.parts[0]
.function_call
)
args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v
args_str = json.dumps(args_dict)
_delta_obj = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": args_str,
"name": function_call.name,
},
"type": "function",
}
],
)
_streaming_response = StreamingChoices(
delta=_delta_obj
)
_model_response = ModelResponse(stream=True)
_model_response.choices = [_streaming_response]
response_obj = {"original_chunk": _model_response}
else:
raise e
if (
hasattr(chunk.candidates[0], "finish_reason")
and chunk.candidates[0].finish_reason.name
!= "FINISH_REASON_UNSPECIFIED"
): # every non-final chunk in vertex ai has this
model_response.choices[0].finish_reason = (
map_finish_reason(
chunk.candidates[0].finish_reason.name
)
)
completion_obj["content"] = chunk.text
except Exception as e:
if chunk.candidates[0].finish_reason.name == "SAFETY":
raise Exception(
f"The response was blocked by VertexAI. {str(chunk)}"
if "Part has no text." in str(e):
## check for function calling
function_call = (
chunk.candidates[0].content.parts[0].function_call
)
else:
completion_obj["content"] = str(chunk)
except StopIteration as e:
if self.sent_last_chunk:
raise e
else:
model_response.choices[0].finish_reason = "stop"
self.sent_last_chunk = True
args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v
args_str = json.dumps(args_dict)
_delta_obj = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": args_str,
"name": function_call.name,
},
"type": "function",
}
],
)
_streaming_response = StreamingChoices(delta=_delta_obj)
_model_response = ModelResponse(stream=True)
_model_response.choices = [_streaming_response]
response_obj = {"original_chunk": _model_response}
else:
raise e
if (
hasattr(chunk.candidates[0], "finish_reason")
and chunk.candidates[0].finish_reason.name
!= "FINISH_REASON_UNSPECIFIED"
): # every non-final chunk in vertex ai has this
self.received_finish_reason = chunk.candidates[
0
].finish_reason.name
except Exception as e:
if chunk.candidates[0].finish_reason.name == "SAFETY":
raise Exception(
f"The response was blocked by VertexAI. {str(chunk)}"
)
else:
completion_obj["content"] = str(chunk)
elif self.custom_llm_provider == "cohere":
response_obj = self.handle_cohere_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "cohere_chat":
response_obj = self.handle_cohere_chat_chunk(chunk)
if response_obj is None:
return
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock":
if self.sent_last_chunk:
if self.received_finish_reason is not None:
raise StopIteration
response_obj = self.handle_bedrock_stream(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.sent_last_chunk = True
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "sagemaker":
verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.sent_last_chunk = True
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.sent_last_chunk:
if self.received_finish_reason is not None:
raise StopIteration
else:
model_response.choices[0].finish_reason = "stop"
self.sent_last_chunk = True
self.received_finish_reason = "stop"
chunk_size = 30
new_chunk = self.completion_stream[:chunk_size]
completion_obj["content"] = new_chunk
@@ -9356,11 +9323,10 @@ class CustomStreamWrapper:
# fake streaming
response_obj = {}
if len(self.completion_stream) == 0:
if self.sent_last_chunk:
if self.received_finish_reason is not None:
raise StopIteration
else:
model_response.choices[0].finish_reason = "stop"
self.sent_last_chunk = True
self.received_finish_reason = "stop"
chunk_size = 30
new_chunk = self.completion_stream[:chunk_size]
completion_obj["content"] = new_chunk
@@ -9371,41 +9337,31 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "ollama_chat":
response_obj = self.handle_ollama_chat_stream(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "cloudflare":
response_obj = self.handle_cloudlfare_stream(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "text-completion-openai":
response_obj = self.handle_openai_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "azure_text":
response_obj = self.handle_azure_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "cached_response":
response_obj = {
"text": chunk.choices[0].delta.content,
@@ -9419,9 +9375,7 @@ class CustomStreamWrapper:
if hasattr(chunk, "id"):
model_response.id = chunk.id
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
else: # openai / azure chat model
if self.custom_llm_provider == "azure":
if hasattr(chunk, "model"):
@@ -9437,9 +9391,7 @@ class CustomStreamWrapper:
raise Exception(
"Mistral API raised a streaming error - finish_reason: error, no content string given."
)
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.received_finish_reason = response_obj["finish_reason"]
if response_obj.get("original_chunk", None) is not None:
model_response.system_fingerprint = getattr(
response_obj["original_chunk"], "system_fingerprint", None
@@ -9451,7 +9403,7 @@ class CustomStreamWrapper:
model_response.model = self.model
print_verbose(
f"model_response finish reason 3: {model_response.choices[0].finish_reason}; response_obj={response_obj}"
f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"
)
## FUNCTION CALL PARSING
if (
@@ -9580,7 +9532,7 @@ class CustomStreamWrapper:
return model_response
else:
return
elif model_response.choices[0].finish_reason is not None:
elif self.received_finish_reason is not None:
# flush any remaining holding chunk
if len(self.holding_chunk) > 0:
if model_response.choices[0].delta.content is None:
@@ -9590,10 +9542,17 @@ class CustomStreamWrapper:
self.holding_chunk + model_response.choices[0].delta.content
)
self.holding_chunk = ""
# get any function call arguments
model_response.choices[0].finish_reason = map_finish_reason(
model_response.choices[0].finish_reason
) # ensure consistent output to openai
# if delta is None
is_delta_empty = self.is_delta_empty(
delta=model_response.choices[0].delta
)
if is_delta_empty:
# get any function call arguments
model_response.choices[0].finish_reason = map_finish_reason(
finish_reason=self.received_finish_reason
) # ensure consistent output to openai
self.sent_last_chunk = True
return model_response
elif (
model_response.choices[0].delta.tool_calls is not None
@@ -9653,6 +9612,16 @@ class CustomStreamWrapper:
## SYNC LOGGING
self.logging_obj.success_handler(processed_chunk)
def finish_reason_handler(self):
model_response = self.model_response_creator()
if self.received_finish_reason is not None:
model_response.choices[0].finish_reason = map_finish_reason(
finish_reason=self.received_finish_reason
)
else:
model_response.choices[0].finish_reason = "stop"
return model_response
## needs to handle the empty string case (even starting chunk can be an empty string)
def __next__(self):
try:
@@ -9687,7 +9656,11 @@ class CustomStreamWrapper:
# RETURN RESULT
return response
except StopIteration:
raise # Re-raise StopIteration
if self.sent_last_chunk == True:
raise # Re-raise StopIteration
else:
self.sent_last_chunk = True
return self.finish_reason_handler()
except Exception as e:
traceback_exception = traceback.format_exc()
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
@@ -9792,9 +9765,17 @@ class CustomStreamWrapper:
# RETURN RESULT
return processed_chunk
except StopAsyncIteration:
raise
if self.sent_last_chunk == True:
raise # Re-raise StopIteration
else:
self.sent_last_chunk = True
return self.finish_reason_handler()
except StopIteration:
raise StopAsyncIteration # Re-raise StopIteration
if self.sent_last_chunk == True:
raise StopAsyncIteration
else:
self.sent_last_chunk = True
return self.finish_reason_handler()
except Exception as e:
traceback_exception = traceback.format_exc()
# Handle any exceptions that might occur during streaming