add SyncResponsesAPIStreamingIterator

parent 3bf2fda128
commit e4cda0a1b7

1 changed file with 136 additions and 50 deletions
@@ -17,12 +17,11 @@ from litellm.utils import CustomStreamWrapper

 COMPLETED_OPENAI_CHUNK_TYPE = "response.completed"


-class ResponsesAPIStreamingIterator:
+class BaseResponsesAPIStreamingIterator:
     """
-    Async iterator for processing streaming responses from the Responses API.
+    Base class for streaming iterators that process responses from the Responses API.

-    This iterator handles the chunked streaming format returned by the Responses API
-    and yields properly formatted ResponsesAPIStreamingResponse objects.
+    This class contains shared logic for both synchronous and asynchronous iterators.
     """

     def __init__(
@@ -35,12 +34,76 @@ class ResponsesAPIStreamingIterator:
         self.response = response
         self.model = model
         self.logging_obj = logging_obj
-        self.stream_iterator = response.aiter_lines()
         self.finished = False
         self.responses_api_provider_config = responses_api_provider_config
-        self.completed_response: Optional[ResponsesAPIStreamingResponse] = None
+        self.completed_response = None
         self.start_time = datetime.now()

+    def _process_chunk(self, chunk):
+        """Process a single chunk of data from the stream"""
+        if not chunk:
+            return None
+
+        # Handle SSE format (data: {...})
+        chunk = CustomStreamWrapper._strip_sse_data_from_chunk(chunk)
+        if chunk is None:
+            return None
+
+        # Handle "[DONE]" marker
+        if chunk == "[DONE]":
+            self.finished = True
+            return None
+
+        try:
+            # Parse the JSON chunk
+            parsed_chunk = json.loads(chunk)
+
+            # Format as ResponsesAPIStreamingResponse
+            if isinstance(parsed_chunk, dict):
+                openai_responses_api_chunk = (
+                    self.responses_api_provider_config.transform_streaming_response(
+                        model=self.model,
+                        parsed_chunk=parsed_chunk,
+                        logging_obj=self.logging_obj,
+                    )
+                )
+                # Store the completed response
+                if (
+                    openai_responses_api_chunk
+                    and openai_responses_api_chunk.type
+                    == ResponsesAPIStreamEvents.RESPONSE_COMPLETED
+                ):
+                    self.completed_response = openai_responses_api_chunk
+                    self._handle_completed_response()
+
+                return openai_responses_api_chunk
+
+            return None
+        except json.JSONDecodeError:
+            # If we can't parse the chunk, continue
+            return None
+
+    def _handle_completed_response(self):
+        """Base implementation - should be overridden by subclasses"""
+        pass
+
+
+class ResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
+    """
+    Async iterator for processing streaming responses from the Responses API.
+    """
+
+    def __init__(
+        self,
+        response: httpx.Response,
+        model: str,
+        responses_api_provider_config: BaseResponsesAPIConfig,
+        logging_obj: LiteLLMLoggingObj,
+    ):
+        super().__init__(response, model, responses_api_provider_config, logging_obj)
+        self.stream_iterator = response.aiter_lines()
+        self.completed_response: Optional[ResponsesAPIStreamingResponse] = None
+
     def __aiter__(self):
         return self

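The `_process_chunk` helper above is the shared piece: it strips the SSE `data:` prefix, treats `[DONE]` as end-of-stream, parses the JSON payload, and hands the dict to the provider config for transformation. As a rough standalone illustration of that flow only (plain strings and `json.loads`, no litellm objects or provider transformation; the function name is ours, not part of this diff):

import json
from typing import Optional


def process_sse_line(line: str) -> Optional[dict]:
    """Illustrative mirror of the shared chunk-processing steps:
    strip the SSE 'data: ' prefix, treat '[DONE]' as end-of-stream,
    and return the parsed JSON payload (or None for non-data lines)."""
    if not line:
        return None
    if line.startswith("data: "):
        line = line[len("data: "):]
    if line.strip() == "[DONE]":
        return None
    try:
        return json.loads(line)
    except json.JSONDecodeError:
        # Unparseable chunks are skipped, matching the iterator's behavior
        return None


# Example: a typical Responses API SSE line
print(process_sse_line('data: {"type": "response.output_text.delta", "delta": "Hi"}'))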
@@ -53,55 +116,78 @@ class ResponsesAPIStreamingIterator:
                 self.finished = True
                 raise StopAsyncIteration

-            if not chunk:
-                return await self.__anext__()
-
-            # Handle SSE format (data: {...})
-            chunk = CustomStreamWrapper._strip_sse_data_from_chunk(chunk)
-            if chunk is None:
-                return await self.__anext__()
-
-            # Handle "[DONE]" marker
-            if chunk == "[DONE]":
-                self.finished = True
-                raise StopAsyncIteration
-
-            try:
-                # Parse the JSON chunk
-                parsed_chunk = json.loads(chunk)
-
-                # Format as ResponsesAPIStreamingResponse
-                if isinstance(parsed_chunk, dict):
-                    openai_responses_api_chunk: ResponsesAPIStreamingResponse = (
-                        self.responses_api_provider_config.transform_streaming_response(
-                            model=self.model,
-                            parsed_chunk=parsed_chunk,
-                            logging_obj=self.logging_obj,
-                        )
-                    )
-                    # Store the completed response
-                    if (
-                        openai_responses_api_chunk
-                        and openai_responses_api_chunk.type
-                        == ResponsesAPIStreamEvents.RESPONSE_COMPLETED
-                    ):
-                        self.completed_response = openai_responses_api_chunk
-                        asyncio.create_task(
-                            self.logging_obj.async_success_handler(
-                                result=self.completed_response,
-                                start_time=self.start_time,
-                                end_time=datetime.now(),
-                                cache_hit=None,
-                            )
-                        )
-                    return openai_responses_api_chunk
-
-                return await self.__anext__()
-            except json.JSONDecodeError:
-                # If we can't parse the chunk, continue to the next one
-                return await self.__anext__()
+            result = self._process_chunk(chunk)
+
+            if self.finished:
+                raise StopAsyncIteration
+            elif result is not None:
+                return result
+            else:
+                return await self.__anext__()

         except httpx.HTTPError as e:
             # Handle HTTP errors
             self.finished = True
             raise e
+
+    def _handle_completed_response(self):
+        """Handle logging for completed responses in async context"""
+        asyncio.create_task(
+            self.logging_obj.async_success_handler(
+                result=self.completed_response,
+                start_time=self.start_time,
+                end_time=datetime.now(),
+                cache_hit=None,
+            )
+        )
+
+
+class SyncResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
+    """
+    Synchronous iterator for processing streaming responses from the Responses API.
+    """
+
+    def __init__(
+        self,
+        response: httpx.Response,
+        model: str,
+        responses_api_provider_config: BaseResponsesAPIConfig,
+        logging_obj: LiteLLMLoggingObj,
+    ):
+        super().__init__(response, model, responses_api_provider_config, logging_obj)
+        self.stream_iterator = response.iter_lines()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            # Get the next chunk from the stream
+            try:
+                chunk = next(self.stream_iterator)
+            except StopIteration:
+                self.finished = True
+                raise StopIteration
+
+            result = self._process_chunk(chunk)
+
+            if self.finished:
+                raise StopIteration
+            elif result is not None:
+                return result
+            else:
+                return self.__next__()
+
+        except httpx.HTTPError as e:
+            # Handle HTTP errors
+            self.finished = True
+            raise e
+
+    def _handle_completed_response(self):
+        """Handle logging for completed responses in sync context"""
+        self.logging_obj.success_handler(
+            result=self.completed_response,
+            start_time=self.start_time,
+            end_time=datetime.now(),
+            cache_hit=None,
+        )
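For orientation, here is a minimal sketch of how the two iterator classes are consumed once they have been constructed: the async class is drained with `async for`, the new sync class with a plain `for` loop. Building the iterators themselves requires an httpx streaming response plus the provider config and logging objects, which are outside this sketch; the `drain_*` helper names are ours, not part of the diff.

import asyncio
from typing import AsyncIterator, Iterator, List


def drain_sync(stream: Iterator) -> List:
    """Drain a SyncResponsesAPIStreamingIterator: __iter__/__next__ pull lines
    from the underlying response.iter_lines() until the stream finishes."""
    events = []
    for event in stream:
        events.append(event)
    return events


async def drain_async(stream: AsyncIterator) -> List:
    """Drain a ResponsesAPIStreamingIterator: __aiter__/__anext__ pull lines
    from response.aiter_lines() inside an event loop."""
    events = []
    async for event in stream:
        events.append(event)
    return events


# Hypothetical wiring, assuming sync_iterator / async_iterator already exist:
#   sync_events = drain_sync(sync_iterator)
#   async_events = asyncio.run(drain_async(async_iterator))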