fix: fix streaming with httpx client

Prevent streams from being overwritten when multiple streaming calls run in parallel.
This commit is contained in:
Krrish Dholakia 2024-05-31 10:55:18 -07:00
parent aada7b4bd3
commit 93c3635b64
9 changed files with 182 additions and 82 deletions

View file

@ -1385,8 +1385,21 @@ def test_bedrock_claude_3_streaming():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
"model",
[
"claude-3-opus-20240229",
"cohere.command-r-plus-v1:0", # bedrock
"databricks/databricks-dbrx-instruct", # databricks
"predibase/llama-3-8b-instruct", # predibase
"replicate/meta/meta-llama-3-8b-instruct", # replicate
],
)
@pytest.mark.asyncio
async def test_claude_3_streaming_finish_reason(sync_mode):
async def test_parallel_streaming_requests(sync_mode, model):
"""
Important prod test.
"""
try:
import threading
@ -1398,7 +1411,7 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
def sync_test_streaming():
response: litellm.CustomStreamWrapper = litellm.acompletion( # type: ignore
model="claude-3-opus-20240229",
model=model,
messages=messages,
stream=True,
max_tokens=10,
@ -1415,7 +1428,7 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
async def test_streaming():
response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore
model="claude-3-opus-20240229",
model=model,
messages=messages,
stream=True,
max_tokens=10,
@ -1424,8 +1437,9 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
# Add any assertions here to check the response
num_finish_reason = 0
async for chunk in response:
print(f"chunk: {chunk}")
print(f"type of chunk: {type(chunk)}")
if isinstance(chunk, ModelResponse):
print(f"OUTSIDE CHUNK: {chunk.choices[0]}")
if chunk.choices[0].finish_reason is not None:
num_finish_reason += 1
assert num_finish_reason == 1