fix(sagemaker.py): support streaming for messages api

Fixes https://github.com/BerriAI/litellm/issues/5372
Krrish Dholakia 2024-08-26 15:08:08 -07:00
parent 73f8315a77
commit b989762bb0
8 changed files with 142 additions and 32 deletions
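With this change, a `sagemaker_chat/` endpoint can be streamed through the usual completion call. A minimal sketch of the sync path, assuming a deployed messages-API endpoint (the endpoint name below is a placeholder) and valid AWS credentials:

    import litellm

    # Stream a chat completion from a SageMaker messages-API endpoint.
    # "<your-endpoint-name>" is a placeholder; substitute your own deployment.
    response = litellm.completion(
        model="sagemaker_chat/<your-endpoint-name>",
        messages=[{"role": "user", "content": "hi - what is ur name"}],
        stream=True,
    )

    full_text = ""
    for chunk in response:
        # Each chunk is an OpenAI-style streaming delta.
        full_text += chunk.choices[0].delta.content or ""
    print(full_text)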


@@ -120,15 +120,24 @@ async def test_completion_sagemaker_messages_api(sync_mode):
 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [False, True])
-async def test_completion_sagemaker_stream(sync_mode):
+@pytest.mark.parametrize(
+    "model",
+    [
+        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+    ],
+)
+async def test_completion_sagemaker_stream(sync_mode, model):
     try:
+        from litellm.tests.test_streaming import streaming_format_tests
+
         litellm.set_verbose = False
         print("testing sagemaker")
         verbose_logger.setLevel(logging.DEBUG)
         full_text = ""
         if sync_mode is True:
             response = litellm.completion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -138,14 +147,15 @@ async def test_completion_sagemaker_stream(sync_mode):
                 input_cost_per_second=0.000420,
             )
-            for chunk in response:
+            for idx, chunk in enumerate(response):
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""
             print("SYNC RESPONSE full text", full_text)
         else:
             response = await litellm.acompletion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -156,10 +166,12 @@ async def test_completion_sagemaker_stream(sync_mode):
             )
             print("streaming response")
+            idx = 0
             async for chunk in response:
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""
+                idx += 1
             print("ASYNC RESPONSE full text", full_text)