mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 10:14:26 +00:00
Merge 6fda0365b8
into b82af5b826
This commit is contained in:
commit
305e1aed2f
1 changed files with 25 additions and 64 deletions
|
@ -1519,7 +1519,6 @@ class VertexLLM(VertexBase):
|
|||
class ModelResponseIterator:
|
||||
def __init__(self, streaming_response, sync_stream: bool):
|
||||
self.streaming_response = streaming_response
|
||||
self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json"
|
||||
self.accumulated_json = ""
|
||||
self.sent_first_chunk = False
|
||||
|
||||
|
@ -1596,78 +1595,38 @@ class ModelResponseIterator:
|
|||
self.response_iterator = self.streaming_response
|
||||
return self
|
||||
|
||||
def handle_valid_json_chunk(self, chunk: str) -> GenericStreamingChunk:
|
||||
chunk = chunk.strip()
|
||||
try:
|
||||
json_chunk = json.loads(chunk)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
if (
|
||||
self.sent_first_chunk is False
|
||||
): # only check for accumulated json, on first chunk, else raise error. Prevent real errors from being masked.
|
||||
self.chunk_type = "accumulated_json"
|
||||
return self.handle_accumulated_json_chunk(chunk=chunk)
|
||||
raise e
|
||||
|
||||
if self.sent_first_chunk is False:
|
||||
self.sent_first_chunk = True
|
||||
|
||||
return self.chunk_parser(chunk=json_chunk)
|
||||
|
||||
def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk:
|
||||
chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
|
||||
message = chunk.replace("\n\n", "")
|
||||
|
||||
# Accumulate JSON data
|
||||
self.accumulated_json += message
|
||||
|
||||
# Try to parse the accumulated JSON
|
||||
def _flush(self) -> Optional[GenericStreamingChunk]:
|
||||
try:
|
||||
_data = json.loads(self.accumulated_json)
|
||||
self.accumulated_json = "" # reset after successful parsing
|
||||
self.accumulated_json = ""
|
||||
return self.chunk_parser(chunk=_data)
|
||||
except json.JSONDecodeError:
|
||||
# If it's not valid JSON yet, continue to the next event
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
return None
|
||||
|
||||
def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
|
||||
try:
|
||||
chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
|
||||
if len(chunk) > 0:
|
||||
"""
|
||||
Check if initial chunk valid json
|
||||
- if partial json -> enter accumulated json logic
|
||||
- if valid - continue
|
||||
"""
|
||||
if self.chunk_type == "valid_json":
|
||||
return self.handle_valid_json_chunk(chunk=chunk)
|
||||
elif self.chunk_type == "accumulated_json":
|
||||
return self.handle_accumulated_json_chunk(chunk=chunk)
|
||||
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
|
||||
self.accumulated_json += chunk
|
||||
chunk = self._flush()
|
||||
if chunk:
|
||||
return chunk
|
||||
# If it's not valid JSON yet, continue to the next event
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
chunk = self.response_iterator.__next__()
|
||||
except StopIteration:
|
||||
if self.chunk_type == "accumulated_json" and self.accumulated_json:
|
||||
return self.handle_accumulated_json_chunk(chunk="")
|
||||
if self.accumulated_json:
|
||||
chunk = self._flush()
|
||||
if chunk:
|
||||
return chunk
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
@ -1688,8 +1647,10 @@ class ModelResponseIterator:
|
|||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
if self.chunk_type == "accumulated_json" and self.accumulated_json:
|
||||
return self.handle_accumulated_json_chunk(chunk="")
|
||||
if self.accumulated_json:
|
||||
chunk = self._flush()
|
||||
if chunk:
|
||||
return chunk
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue