fix(utils.py): flush holding chunk for streaming, on stream end

Krrish Dholakia 2023-12-12 16:13:31 -08:00
parent 7efcc550e2
commit 669862643b
2 changed files with 27 additions and 2 deletions
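
The bug: litellm.CustomStreamWrapper keeps a holding_chunk of text that it holds back from the caller (for example while checking whether the tail of a chunk is the start of a special token), and before this commit anything still held when the provider reported a finish_reason was silently dropped. Below is a minimal sketch of the flush-on-finish pattern the diff adds; the HeldStreamSketch name and its chunk shape are illustrative assumptions, not litellm's actual API.

# Minimal sketch of flush-on-finish: text buffered in holding_chunk is
# emitted with the final chunk instead of being lost at stream end.
class HeldStreamSketch:
    def __init__(self, pieces):
        self.pieces = iter(pieces)
        self.holding_chunk = ""  # text buffered, one piece behind the stream

    def __iter__(self):
        return self

    def __next__(self):
        try:
            piece = next(self.pieces)
            out, self.holding_chunk = self.holding_chunk, piece
            return {"choices": [{"delta": {"content": out}, "finish_reason": None}]}
        except StopIteration:
            if self.holding_chunk:  # flush any remaining holding chunk
                out, self.holding_chunk = self.holding_chunk, ""
                return {"choices": [{"delta": {"content": out}, "finish_reason": "stop"}]}
            raise

stream = HeldStreamSketch(["Hello", ", world"])
text = "".join(c["choices"][0]["delta"]["content"] for c in stream)
assert text == "Hello, world"  # without the flush, ", world" would be dropped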


@@ -245,7 +245,6 @@ def test_completion_azure_stream():
         complete_response = ""
         # Add any assertions here to check the response
         for idx, init_chunk in enumerate(response):
-            print(f"azure chunk: {init_chunk}")
             chunk, finished = streaming_format_tests(idx, init_chunk)
             complete_response += chunk
             if finished:
@@ -255,7 +254,7 @@ def test_completion_azure_stream():
             raise Exception("Empty response received")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_azure_stream()
+# test_completion_azure_stream()
 
 def test_completion_azure_function_calling_stream():
     try:
@@ -636,6 +635,25 @@ def test_completion_bedrock_ai21_stream():
 # test_completion_bedrock_ai21_stream()
 
+
+def test_sagemaker_weird_response():
+    """
+    When the stream ends, flush any remaining holding chunks.
+    """
+    try:
+        chunk = """<s>[INST] Hey, how's it going? [/INST]
+
+ I'm doing well, thanks for asking! How about you? Is there anything you'd like to chat about or ask? I'm here to help with any questions you might have."""
+        logging_obj = litellm.Logging(model="berri-benchmarking-Llama-2-70b-chat-hf-4", messages=messages, stream=True, litellm_call_id="1234", function_id="function_id", call_type="acompletion", start_time=time.time())
+        response = litellm.CustomStreamWrapper(completion_stream=chunk, model="berri-benchmarking-Llama-2-70b-chat-hf-4", custom_llm_provider="sagemaker", logging_obj=logging_obj)
+        complete_response = ""
+        for chunk in response:
+            complete_response += chunk["choices"][0]["delta"]["content"]
+        assert len(complete_response) > 0
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
+# test_sagemaker_weird_response()
+
 # def test_completion_sagemaker_stream():
 #     try:
 #         response = completion(


@@ -5682,6 +5682,13 @@ class CustomStreamWrapper:
             else:
                 return
         elif model_response.choices[0].finish_reason:
+            # flush any remaining holding chunk
+            if len(self.holding_chunk) > 0:
+                if model_response.choices[0].delta.content is None:
+                    model_response.choices[0].delta.content = self.holding_chunk
+                else:
+                    model_response.choices[0].delta.content = self.holding_chunk + model_response.choices[0].delta.content
+                self.holding_chunk = ""
             model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason)  # ensure consistent output to openai
             return model_response
         elif response_obj is not None and response_obj.get("original_chunk", None) is not None:  # function / tool calling branch - only set for openai/azure compatible endpoints
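
For reference, the fix can be exercised the same way the new test above does: feed a raw string through CustomStreamWrapper as a fake sagemaker stream and check that the full text survives the end of the stream. A compressed sketch, using only the litellm.Logging and litellm.CustomStreamWrapper signatures that appear in the diff; the messages value, which the test takes from module scope, is filled in here as an assumption.

import time
import litellm

messages = [{"role": "user", "content": "Hey, how's it going?"}]
logging_obj = litellm.Logging(
    model="berri-benchmarking-Llama-2-70b-chat-hf-4",
    messages=messages,
    stream=True,
    litellm_call_id="1234",
    function_id="function_id",
    call_type="acompletion",
    start_time=time.time(),
)
# A plain string stands in for the provider stream, as in the new test;
# before the fix, its tail could be dropped when the stream ended.
response = litellm.CustomStreamWrapper(
    completion_stream="<s>[INST] Hey, how's it going? [/INST] I'm doing well, thanks!",
    model="berri-benchmarking-Llama-2-70b-chat-hf-4",
    custom_llm_provider="sagemaker",
    logging_obj=logging_obj,
)
complete_response = "".join(
    chunk["choices"][0]["delta"]["content"] for chunk in response
)
assert len(complete_response) > 0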