fix: streaming with httpx client

prevent overwriting streams in parallel streaming calls
Krrish Dholakia 2024-05-31 10:55:18 -07:00
parent aada7b4bd3
commit 93c3635b64
9 changed files with 182 additions and 82 deletions
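
For context, a minimal sketch of the failure mode this commit addresses (illustrative names, not litellm code): when the completion stream is opened eagerly and parked on shared state, parallel streaming calls can overwrite one another's stream before it is consumed. Deferring the fetch behind a per-wrapper make_call callable keeps every call on its own stream.

import asyncio

shared = {"stream": None}  # stand-in for module-level shared state

async def open_stream(tag):
    async def gen():
        for i in range(3):
            yield f"{tag}-chunk-{i}"
    return gen()

async def eager_call(tag):
    # Bug pattern: both tasks write to the same slot before either reads it.
    shared["stream"] = await open_stream(tag)
    await asyncio.sleep(0)  # yield control; the other task overwrites the slot
    return [chunk async for chunk in shared["stream"]]

async def lazy_call(tag):
    # Fix pattern: each caller holds its own make_call and resolves it privately.
    make_call = lambda: open_stream(tag)
    stream = await make_call()  # fetched lazily, stored per caller
    return [chunk async for chunk in stream]

async def main():
    a, b = await asyncio.gather(eager_call("A"), eager_call("B"))
    print("eager:", a, b)  # one task consumes the other's chunks; the other sees an exhausted stream
    a, b = await asyncio.gather(lazy_call("A"), lazy_call("B"))
    print("lazy: ", a, b)  # each task sees only its own chunks

asyncio.run(main())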


@@ -32,7 +32,7 @@ from dataclasses import (
 )
 
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
-from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
 from litellm.caching import DualCache
 from litellm.types.utils import CostPerToken, ProviderField, ModelInfo
@@ -10214,8 +10214,10 @@ class CustomStreamWrapper:
         custom_llm_provider=None,
         logging_obj=None,
         stream_options=None,
+        make_call: Optional[Callable] = None,
     ):
         self.model = model
+        self.make_call = make_call
         self.custom_llm_provider = custom_llm_provider
         self.logging_obj = logging_obj
         self.completion_stream = completion_stream
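
The new make_call parameter lets a provider construct the wrapper before any network request is made, with the wrapper opening the stream on demand. A sketch of the wiring, modeled on the pattern in this commit (the helper name _make_provider_call and the exact AsyncHTTPHandler streaming API are assumptions for illustration):

import json
from functools import partial

from litellm.utils import CustomStreamWrapper

async def _make_provider_call(client, api_base, headers, data):
    # `client` is the AsyncHTTPHandler the wrapper passes in
    # (litellm.module_level_aclient); assuming it supports stream=True.
    response = await client.post(api_base, headers=headers, data=json.dumps(data), stream=True)
    return response.aiter_lines()

def build_wrapper(api_base, headers, data, logging_obj):
    return CustomStreamWrapper(
        completion_stream=None,  # nothing opened yet; fetched on first __anext__
        model="my-model",
        custom_llm_provider="my-provider",
        logging_obj=logging_obj,
        make_call=partial(_make_provider_call, api_base=api_base, headers=headers, data=data),
    )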
@@ -11766,8 +11768,20 @@ class CustomStreamWrapper:
                 custom_llm_provider=self.custom_llm_provider,
             )
 
+    async def fetch_stream(self):
+        if self.completion_stream is None and self.make_call is not None:
+            # Call make_call to get the completion stream
+            self.completion_stream = await self.make_call(
+                client=litellm.module_level_aclient
+            )
+            self._stream_iter = self.completion_stream.__aiter__()
+
+        return self.completion_stream
+
     async def __anext__(self):
         try:
+            if self.completion_stream is None:
+                await self.fetch_stream()
             if (
                 self.custom_llm_provider == "openai"
                 or self.custom_llm_provider == "azure"
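
Note that the `completion_stream is None` guard also makes the fetch idempotent: the stream is opened exactly once per wrapper instance and cached on `self`, so repeated iteration cannot re-trigger the network call. A stand-in sketch of that behavior (not litellm code):

import asyncio

class OnceFetcher:
    """Stand-in for the wrapper's caching guard; not litellm code."""

    def __init__(self, make_call):
        self.completion_stream = None
        self.make_call = make_call

    async def fetch_stream(self):
        # Same guard as above: only fetch while no stream is cached.
        if self.completion_stream is None and self.make_call is not None:
            self.completion_stream = await self.make_call()
            self._stream_iter = self.completion_stream.__aiter__()
        return self.completion_stream

async def main():
    calls = 0

    async def make_call():
        nonlocal calls
        calls += 1

        async def gen():
            yield "chunk"

        return gen()

    f = OnceFetcher(make_call)
    s1 = await f.fetch_stream()
    s2 = await f.fetch_stream()
    assert s1 is s2 and calls == 1  # opened exactly once, cached on the instance

asyncio.run(main())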