Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 11:43:54 +00:00
fix(sagemaker.py): support streaming for messages api
Fixes https://github.com/BerriAI/litellm/issues/5372
parent 73f8315a77
commit b989762bb0
8 changed files with 142 additions and 32 deletions
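The commit parametrizes the existing streaming test over both the SageMaker messages-API ("sagemaker_chat/") route and the plain "sagemaker/" route, and adds per-chunk format assertions. For context, here is a minimal usage sketch, not part of the commit, of the call path being fixed: streaming a chat completion from a SageMaker messages-API endpoint through litellm. The endpoint names are copied from the test below; substitute your own deployed endpoint.

    import asyncio

    import litellm

    # Endpoint name copied from the test in this commit; replace with your own
    # deployed SageMaker endpoint. The "sagemaker_chat/" prefix routes through
    # the messages (chat) API that this commit fixes for streaming.
    MODEL = "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
    MESSAGES = [{"role": "user", "content": "hi - what is ur name"}]


    def stream_sync() -> str:
        # Synchronous streaming: litellm.completion(stream=True) yields chunks.
        response = litellm.completion(model=MODEL, messages=MESSAGES, stream=True)
        full_text = ""
        for chunk in response:
            full_text += chunk.choices[0].delta.content or ""
        return full_text


    async def stream_async() -> str:
        # Async streaming: the same call via litellm.acompletion.
        response = await litellm.acompletion(model=MODEL, messages=MESSAGES, stream=True)
        full_text = ""
        async for chunk in response:
            full_text += chunk.choices[0].delta.content or ""
        return full_text


    if __name__ == "__main__":
        print("sync:", stream_sync())
        print("async:", asyncio.run(stream_async()))

The test changes from the commit follow.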
@@ -120,15 +120,24 @@ async def test_completion_sagemaker_messages_api(sync_mode):
 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [False, True])
-async def test_completion_sagemaker_stream(sync_mode):
+@pytest.mark.parametrize(
+    "model",
+    [
+        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+    ],
+)
+async def test_completion_sagemaker_stream(sync_mode, model):
     try:
+        from litellm.tests.test_streaming import streaming_format_tests
+
         litellm.set_verbose = False
         print("testing sagemaker")
         verbose_logger.setLevel(logging.DEBUG)
         full_text = ""
         if sync_mode is True:
             response = litellm.completion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -138,14 +147,15 @@ async def test_completion_sagemaker_stream(sync_mode):
                 input_cost_per_second=0.000420,
             )

-            for chunk in response:
+            for idx, chunk in enumerate(response):
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""

             print("SYNC RESPONSE full text", full_text)
         else:
             response = await litellm.acompletion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -156,10 +166,12 @@ async def test_completion_sagemaker_stream(sync_mode):
             )

             print("streaming response")

+            idx = 0
             async for chunk in response:
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""
+                idx += 1

             print("ASYNC RESPONSE full text", full_text)
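The new streaming_format_tests(idx=idx, chunk=chunk) calls validate each chunk's shape as it arrives. A rough, hypothetical sketch of that kind of per-chunk check, assuming OpenAI-compatible streaming chunks (the actual assertions live in litellm/tests/test_streaming.py):

    def check_stream_chunk(idx: int, chunk) -> None:
        # Hypothetical shape check, assuming OpenAI-compatible streaming
        # chunks; the real assertions live in
        # litellm.tests.test_streaming.streaming_format_tests.
        assert chunk.object == "chat.completion.chunk"
        assert len(chunk.choices) >= 1
        delta = chunk.choices[0].delta
        if idx == 0:
            # In OpenAI-style streams, only the first chunk carries the role.
            assert delta.role == "assistant"
        # content may be None (e.g. on the final chunk); otherwise it is text.
        assert delta.content is None or isinstance(delta.content, str)

Passing idx explicitly, as the test now does, is what lets the first chunk be checked differently from the rest.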