(test) fix sagemaker stream test

ishaan-jaff 2024-01-23 10:07:13 -08:00
parent 9aa40c63ee
commit e2e56c03f4


@@ -274,7 +274,7 @@ def test_completion_azure_stream():
         pytest.fail(f"Error occurred: {e}")


-test_completion_azure_stream()
+# test_completion_azure_stream()


def test_completion_azure_function_calling_stream():
@@ -799,10 +799,31 @@ def test_sagemaker_weird_response():
     When the stream ends, flush any remaining holding chunks.
     """
     try:
-        chunk = """<s>[INST] Hey, how's it going? [/INST]
+        import json
+
+        from litellm.llms.sagemaker import TokenIterator
+
+        chunk = """<s>[INST] Hey, how's it going? [/INST],

 I'm doing well, thanks for asking! How about you? Is there anything you'd like to chat about or ask? I'm here to help with any questions you might have."""
+        data = "\n".join(
+            map(
+                lambda x: f"data: {json.dumps({'token': {'text': x.strip()}})}",
+                chunk.strip().split(","),
+            )
+        )
+        stream = bytes(data, encoding="utf8")
+
+        # Modify the array to be a dictionary with "PayloadPart" and "Bytes" keys.
+        stream_iterator = iter([{"PayloadPart": {"Bytes": stream}}])
+        token_iter = TokenIterator(stream_iterator)
+
+        # for token in token_iter:
+        #     print(token)
+
         litellm.set_verbose = True
         logging_obj = litellm.Logging(
             model="berri-benchmarking-Llama-2-70b-chat-hf-4",
             messages=messages,
@@ -813,13 +834,14 @@ def test_sagemaker_weird_response():
             start_time=time.time(),
         )
         response = litellm.CustomStreamWrapper(
-            completion_stream=chunk,
+            completion_stream=token_iter,
             model="berri-benchmarking-Llama-2-70b-chat-hf-4",
             custom_llm_provider="sagemaker",
             logging_obj=logging_obj,
         )
         complete_response = ""
         for chunk in response:
+            print(chunk)
             complete_response += chunk["choices"][0]["delta"]["content"]
         assert len(complete_response) > 0
     except Exception as e:
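
For anyone who wants to poke at this fixture without a live SageMaker endpoint: below is a minimal, self-contained sketch of the fake event stream the test builds. The payload construction is copied from the diff above; parse_tokens is a hypothetical stand-in for what litellm.llms.sagemaker.TokenIterator is assumed to do (pull the "data: {...}" lines out of each PayloadPart and yield the token text) — it is not litellm's actual implementation.

    import json

    # Fake LLM output, copied from the test: commas delimit the "tokens".
    chunk = """<s>[INST] Hey, how's it going? [/INST],

    I'm doing well, thanks for asking! How about you? Is there anything you'd like to chat about or ask? I'm here to help with any questions you might have."""

    # Build the fake SageMaker event stream exactly as the test does:
    # one "data: {...}" line per comma-separated fragment.
    data = "\n".join(
        f"data: {json.dumps({'token': {'text': x.strip()}})}"
        for x in chunk.strip().split(",")
    )
    stream_iterator = iter([{"PayloadPart": {"Bytes": bytes(data, encoding="utf8")}}])

    def parse_tokens(payload_parts):
        # Hypothetical stand-in for TokenIterator (an assumption, not
        # litellm's code): decode each PayloadPart, find the "data: ..."
        # lines, and yield each token's text.
        for part in payload_parts:
            for line in part["PayloadPart"]["Bytes"].decode("utf8").splitlines():
                if line.startswith("data: "):
                    yield json.loads(line[len("data: "):])["token"]["text"]

    for token in parse_tokens(stream_iterator):
        print(token)  # e.g. "<s>[INST] Hey", "how's it going? [/INST]", ...

The test itself wraps the real TokenIterator in litellm.CustomStreamWrapper and asserts that a non-empty completion is reassembled from the deltas.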