Merge pull request #3944 from BerriAI/litellm_fix_parallel_streaming

fix: fix streaming with httpx client
Krish Dholakia 2024-05-31 21:42:37 -07:00 committed by GitHub
commit e7ff3adc26
9 changed files with 182 additions and 82 deletions


@@ -32,7 +32,7 @@ from dataclasses import (
 )
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
-from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
 from litellm.caching import DualCache
 from litellm.types.utils import CostPerToken, ProviderField, ModelInfo
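
The newly imported AsyncHTTPHandler backs the async streaming path; per the module path (custom_httpx) and the commit title, it presumably wraps an httpx async client. Below is a minimal sketch of the underlying httpx streaming pattern, not LiteLLM's actual request code — the URL and payload are hypothetical placeholders.

import asyncio
import httpx

async def stream_chat(url: str, payload: dict) -> None:
    async with httpx.AsyncClient() as client:
        # stream() keeps the response open so chunks can be read lazily
        async with client.stream("POST", url, json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line:
                    print(line)

# Usage (hypothetical endpoint):
# asyncio.run(stream_chat("https://api.example.com/v1/chat", {"stream": True}))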
@@ -10217,8 +10217,10 @@ class CustomStreamWrapper:
         custom_llm_provider=None,
         logging_obj=None,
         stream_options=None,
+        make_call: Optional[Callable] = None,
     ):
         self.model = model
+        self.make_call = make_call
         self.custom_llm_provider = custom_llm_provider
         self.logging_obj = logging_obj
         self.completion_stream = completion_stream
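
With make_call on the constructor, a caller can hand the wrapper a stream factory instead of an already-open stream. A hedged sketch of that construction path follows; only the make_call parameter itself comes from this diff, the factory body is hypothetical, and CustomStreamWrapper is assumed importable from litellm.utils (the file being patched).

from litellm.utils import CustomStreamWrapper

async def make_call(client):
    # Hypothetical factory: open the provider's HTTP stream only when
    # the wrapper first needs it, using the async client it is handed.
    response = await client.post(
        "https://api.example.com/v1/chat",  # placeholder endpoint
        json={"stream": True},
        stream=True,  # assumed: the handler's post() supports stream=True
    )
    return response.aiter_lines()

wrapper = CustomStreamWrapper(
    completion_stream=None,  # nothing fetched yet; make_call fills this in
    model="gpt-3.5-turbo",
    custom_llm_provider="openai",
    make_call=make_call,
)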
@@ -11769,8 +11771,20 @@ class CustomStreamWrapper:
                 custom_llm_provider=self.custom_llm_provider,
             )
 
+    async def fetch_stream(self):
+        if self.completion_stream is None and self.make_call is not None:
+            # Call make_call to get the completion stream
+            self.completion_stream = await self.make_call(
+                client=litellm.module_level_aclient
+            )
+            self._stream_iter = self.completion_stream.__aiter__()
+
+        return self.completion_stream
+
     async def __anext__(self):
         try:
+            if self.completion_stream is None:
+                await self.fetch_stream()
+
             if (
                 self.custom_llm_provider == "openai"
                 or self.custom_llm_provider == "azure"
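
The added fetch_stream and the guard in __anext__ implement lazy stream creation: the wrapper defers opening the HTTP stream until the first chunk is requested, using the shared litellm.module_level_aclient. The standalone sketch below illustrates the same pattern outside LiteLLM; every name in it is invented for the example.

import asyncio
from typing import AsyncIterator, Awaitable, Callable, Optional

class LazyStream:
    """Defers creating its underlying async iterator until first use."""

    def __init__(self, make_call: Callable[[], Awaitable[AsyncIterator[str]]]):
        self.make_call = make_call
        self.stream: Optional[AsyncIterator[str]] = None

    def __aiter__(self) -> "LazyStream":
        return self

    async def __anext__(self) -> str:
        if self.stream is None:
            # Mirrors fetch_stream(): open the stream on first __anext__
            self.stream = await self.make_call()
        return await self.stream.__anext__()

async def fake_provider_stream() -> AsyncIterator[str]:
    for i in range(3):
        yield f"chunk-{i}"

async def main() -> None:
    async def make_call() -> AsyncIterator[str]:
        return fake_provider_stream()  # pretend this opened an HTTP stream

    async for chunk in LazyStream(make_call):
        print(chunk)

asyncio.run(main())

Given the source branch name (litellm_fix_parallel_streaming), a plausible reading is that deferring the request lets each parallel consumer open its connection on the shared async client only when it actually starts iterating, though the diff itself does not state this rationale.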