diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 79d795670..f3640c27c 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -12,6 +12,7 @@ from functools import partial
 from typing import Any, Callable, List, Literal, Optional, Tuple, Union
 
 import httpx  # type: ignore
+import ijson
 import requests  # type: ignore
 
 import litellm
@@ -257,7 +258,7 @@ async def make_call(
         raise VertexAIError(status_code=response.status_code, message=response.text)
 
     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_bytes(chunk_size=2056)
+        streaming_response=response.aiter_bytes(), sync_stream=False
     )
     # LOGGING
     logging_obj.post_call(
@@ -288,7 +289,7 @@ def make_sync_call(
         raise VertexAIError(status_code=response.status_code, message=response.read())
 
     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_bytes(chunk_size=2056)
+        streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
     )
 
     # LOGGING
@@ -705,6 +706,25 @@ class VertexLLM(BaseLLM):
 
         ### ROUTING (ASYNC, STREAMING, SYNC)
         if acompletion:
+            ### ASYNC STREAMING
+            if stream is True:
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=json.dumps(data),  # type: ignore
+                    api_base=url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=stream,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                    client=client,  # type: ignore
+                )
             ### ASYNC COMPLETION
             return self.async_completion(
                 model=model,
@@ -916,9 +936,13 @@ class VertexLLM(BaseLLM):
 
 
 class ModelResponseIterator:
-    def __init__(self, streaming_response):
+    def __init__(self, streaming_response, sync_stream: bool):
         self.streaming_response = streaming_response
-        self.response_iterator = iter(self.streaming_response)
+        if sync_stream:
+            self.response_iterator = iter(self.streaming_response)
+
+        self.events = ijson.sendable_list()
+        self.coro = ijson.items_coro(self.events, "item")
 
     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
@@ -970,10 +994,21 @@ class ModelResponseIterator:
 
     def __next__(self):
         try:
-            chunk = next(self.response_iterator)
-            chunk = chunk.decode()
-            json_chunk = json.loads(chunk)
-            return self.chunk_parser(chunk=json_chunk)
+            chunk = self.response_iterator.__next__()
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
         except StopIteration:
             raise StopIteration
         except ValueError as e:
@@ -987,9 +1022,20 @@ class ModelResponseIterator:
     async def __anext__(self):
         try:
             chunk = await self.async_response_iterator.__anext__()
-            chunk = chunk.decode()
-            json_chunk = json.loads(chunk)
-            return self.chunk_parser(chunk=json_chunk)
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
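
The core of the change is replacing per-chunk `json.loads` with ijson's push-based parser, so the streamed Vertex response (a top-level JSON array of candidate chunks) can be parsed even when an HTTP chunk boundary falls in the middle of a JSON object. Below is a minimal sketch of the `sendable_list` / `items_coro` pattern the new `ModelResponseIterator` relies on; the sample byte chunks are made up for illustration.

```python
import ijson

# Hypothetical byte chunks (illustrative only): a top-level JSON array split at
# arbitrary boundaries, the way an HTTP byte stream might deliver it.
chunks = [
    b'[{"candidates": [{"content": {"parts": [{"text": "Hel',
    b'lo"}]}}]}, {"candidates": [{"content": {"parts": [{"text": " world"}]}}]}]',
]

events = ijson.sendable_list()           # completed array elements land here
coro = ijson.items_coro(events, "item")  # push parser for top-level array items

for chunk in chunks:
    coro.send(chunk)          # feed raw bytes as they arrive
    while events:             # zero or more elements may have completed
        print(events.pop(0))  # e.g. {'candidates': [{'content': {...}}]}
coro.close()                  # flush the parser once the stream ends
```

This is also why `__next__` / `__anext__` return an empty `GenericStreamingChunk` when `self.events` is empty: feeding one network chunk into the parser does not guarantee that a complete JSON element has become available yet.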