(test) fix sagemaker stream test

This commit is contained in:
ishaan-jaff 2024-01-23 10:07:13 -08:00
parent 9aa40c63ee
commit e2e56c03f4

View file

@@ -274,7 +274,7 @@ def test_completion_azure_stream():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
test_completion_azure_stream() # test_completion_azure_stream()
def test_completion_azure_function_calling_stream(): def test_completion_azure_function_calling_stream():
@@ -799,9 +799,30 @@ def test_sagemaker_weird_response():
When the stream ends, flush any remaining holding chunks. When the stream ends, flush any remaining holding chunks.
""" """
try: try:
chunk = """<s>[INST] Hey, how's it going? [/INST] from litellm.llms.sagemaker import TokenIterator
import json
import json
from litellm.llms.sagemaker import TokenIterator
I'm doing well, thanks for asking! How about you? Is there anything you'd like to chat about or ask? I'm here to help with any questions you might have.""" chunk = """<s>[INST] Hey, how's it going? [/INST],
I'm doing well, thanks for asking! How about you? Is there anything you'd like to chat about or ask? I'm here to help with any questions you might have."""
data = "\n".join(
map(
lambda x: f"data: {json.dumps({'token': {'text': x.strip()}})}",
chunk.strip().split(","),
)
)
stream = bytes(data, encoding="utf8")
# Modify the array to be a dictionary with "PayloadPart" and "Bytes" keys.
stream_iterator = iter([{"PayloadPart": {"Bytes": stream}}])
token_iter = TokenIterator(stream_iterator)
# for token in token_iter:
# print(token)
litellm.set_verbose = True
logging_obj = litellm.Logging( logging_obj = litellm.Logging(
model="berri-benchmarking-Llama-2-70b-chat-hf-4", model="berri-benchmarking-Llama-2-70b-chat-hf-4",
@@ -813,13 +834,14 @@ def test_sagemaker_weird_response():
start_time=time.time(), start_time=time.time(),
) )
response = litellm.CustomStreamWrapper( response = litellm.CustomStreamWrapper(
completion_stream=chunk, completion_stream=token_iter,
model="berri-benchmarking-Llama-2-70b-chat-hf-4", model="berri-benchmarking-Llama-2-70b-chat-hf-4",
custom_llm_provider="sagemaker", custom_llm_provider="sagemaker",
logging_obj=logging_obj, logging_obj=logging_obj,
) )
complete_response = "" complete_response = ""
for chunk in response: for chunk in response:
print(chunk)
complete_response += chunk["choices"][0]["delta"]["content"] complete_response += chunk["choices"][0]["delta"]["content"]
assert len(complete_response) > 0 assert len(complete_response) > 0
except Exception as e: except Exception as e: