MockResponsesAPIStreamingIterator

Ishaan Jaff 2025-03-20 12:30:09 -07:00
parent 55115bf520
commit a29587e178


@@ -20,6 +20,7 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.responses.streaming_iterator import (
BaseResponsesAPIStreamingIterator,
MockResponsesAPIStreamingIterator,
ResponsesAPIStreamingIterator,
SyncResponsesAPIStreamingIterator,
)
@@ -1004,6 +1005,7 @@ class BaseLLMHTTPHandler:
extra_body=extra_body,
timeout=timeout,
client=client if isinstance(client, AsyncHTTPHandler) else None,
fake_stream=fake_stream,
)
if client is None or not isinstance(client, HTTPHandler):
@@ -1052,13 +1054,26 @@
try:
if stream:
# For streaming, use stream=True in the request
if fake_stream is True:
stream, data = self._prepare_fake_stream_request(
stream=stream,
data=data,
fake_stream=fake_stream,
)
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout
or response_api_optional_request_params.get("timeout"),
stream=True,
stream=stream,
)
if fake_stream is True:
return MockResponsesAPIStreamingIterator(
response=response,
model=model,
logging_obj=logging_obj,
responses_api_provider_config=responses_api_provider_config,
)
return SyncResponsesAPIStreamingIterator(
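
The sync path above downgrades the request when fake_stream is enabled: _prepare_fake_stream_request strips the stream flag, the provider returns one complete response, and that response is handed to MockResponsesAPIStreamingIterator instead of the real streaming iterator. The sketch below illustrates the same pattern in isolation; it is not code from this commit, and the class and function names are hypothetical stand-ins.

# Hypothetical, self-contained illustration of the fake-stream pattern
# (not LiteLLM code): send a plain request, then replay the complete
# response through an iterator so callers can still consume it chunk-by-chunk.
import json
from typing import Iterator

import httpx


class FakeStreamIterator:
    """Yields one complete httpx.Response body as a single streamed chunk."""

    def __init__(self, response: httpx.Response) -> None:
        self._chunks = [response.json()]

    def __iter__(self) -> Iterator[dict]:
        return iter(self._chunks)


def post_with_fake_stream(client: httpx.Client, url: str, data: dict) -> FakeStreamIterator:
    payload = dict(data)
    payload.pop("stream", None)  # ask the provider for a complete, non-streamed body
    response = client.post(url, content=json.dumps(payload))
    return FakeStreamIterator(response)
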
@@ -1147,20 +1162,34 @@
"headers": headers,
},
)
# Check if streaming is requested
stream = response_api_optional_request_params.get("stream", False)
try:
if stream:
# For streaming, we need to use stream=True in the request
if fake_stream is True:
stream, data = self._prepare_fake_stream_request(
stream=stream,
data=data,
fake_stream=fake_stream,
)
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout
or response_api_optional_request_params.get("timeout"),
stream=True,
stream=stream,
)
if fake_stream is True:
return MockResponsesAPIStreamingIterator(
response=response,
model=model,
logging_obj=logging_obj,
responses_api_provider_config=responses_api_provider_config,
)
# Return the streaming iterator
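
The async branch mirrors the sync one: with fake_stream enabled, the awaited post is non-streaming and its result is wrapped in MockResponsesAPIStreamingIterator. Below is an async counterpart of the earlier sketch, again with hypothetical names rather than the actual LiteLLM classes.

# Hypothetical async illustration of the same pattern (not LiteLLM code).
import json

import httpx


class AsyncFakeStreamIterator:
    """Exposes one complete response through the async-iterator protocol."""

    def __init__(self, response: httpx.Response) -> None:
        self._payload = response.json()
        self._done = False

    def __aiter__(self) -> "AsyncFakeStreamIterator":
        return self

    async def __anext__(self) -> dict:
        if self._done:
            raise StopAsyncIteration
        self._done = True
        return self._payload


async def apost_with_fake_stream(
    client: httpx.AsyncClient, url: str, data: dict
) -> AsyncFakeStreamIterator:
    payload = dict(data)
    payload.pop("stream", None)  # request a complete, non-streamed body
    response = await client.post(url, content=json.dumps(payload))
    return AsyncFakeStreamIterator(response)
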
@@ -1179,6 +1208,7 @@
timeout=timeout
or response_api_optional_request_params.get("timeout"),
)
except Exception as e:
raise self._handle_error(
e=e,
@@ -1191,6 +1221,21 @@
logging_obj=logging_obj,
)
def _prepare_fake_stream_request(
self,
stream: bool,
data: dict,
fake_stream: bool,
) -> Tuple[bool, dict]:
"""
Handles preparing a request when `fake_stream` is True.
"""
if fake_stream is True:
stream = False
data.pop("stream", None)
return stream, data
return stream, data
def _handle_error(
self,
e: Exception,
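
For reference, the effect of the new _prepare_fake_stream_request helper can be shown with a small standalone example. The function below is a hedged re-implementation for illustration only (the real method lives on BaseLLMHTTPHandler), and the request payload values are made up.

from typing import Tuple


def prepare_fake_stream_request(stream: bool, data: dict, fake_stream: bool) -> Tuple[bool, dict]:
    # Mirrors the helper above: when faking a stream, downgrade to a
    # non-streaming request and drop the "stream" key from the payload.
    if fake_stream is True:
        stream = False
        data.pop("stream", None)
    return stream, data


# A streaming call with fake_stream enabled becomes a plain request.
stream, data = prepare_fake_stream_request(
    stream=True,
    data={"model": "gpt-4o", "input": "hello", "stream": True},
    fake_stream=True,
)
assert stream is False and "stream" not in data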