fix(vertex_httpx.py): support async streaming for google ai studio gemini
commit e92570534c · parent be66800a98
1 changed file with 57 additions and 11 deletions
vertex_httpx.py

@@ -12,6 +12,7 @@ from functools import partial
 from typing import Any, Callable, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
+import ijson
 import requests  # type: ignore

 import litellm
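Note on the new ijson dependency: Google AI Studio's streaming responses arrive as elements of one large JSON array, so an individual byte chunk is rarely a complete JSON document on its own. ijson's coroutine interface parses the array incrementally and appends each completed element to a target list. A minimal standalone sketch of the pattern the iterator changes below rely on (not part of the diff):

    import ijson

    # Completed top-level array elements are appended here as they finish parsing.
    events = ijson.sendable_list()
    # The "item" prefix selects each element of the outermost JSON array.
    coro = ijson.items_coro(events, "item")

    # Feed the document in arbitrary byte-sized pieces, as a streaming HTTP
    # client would deliver them.
    coro.send(b'[{"text": "Hel')      # partial element: nothing emitted yet
    assert events == []
    coro.send(b'lo"}, {"text": ')     # first element completed
    assert events == [{"text": "Hello"}]
    events.clear()
    coro.send(b'"world"}]')
    assert events == [{"text": "world"}]

Each completed element is a plain dict, which is the shape chunk_parser() below expects to receive.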
@@ -257,7 +258,7 @@ async def make_call(
         raise VertexAIError(status_code=response.status_code, message=response.text)

     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_bytes(chunk_size=2056)
+        streaming_response=response.aiter_bytes(), sync_stream=False
     )
     # LOGGING
     logging_obj.post_call(
@@ -288,7 +289,7 @@ def make_sync_call(
         raise VertexAIError(status_code=response.status_code, message=response.read())

     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_bytes(chunk_size=2056)
+        streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
     )

     # LOGGING
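Both call sites construct the same ModelResponseIterator but now flag which transport feeds it: make_sync_call keeps iter_bytes(chunk_size=2056) and passes sync_stream=True so the constructor can build a plain iterator, while make_call hands over the aiter_bytes() async generator with sync_stream=False, since calling iter() on an async generator raises a TypeError. A quick illustration of that constraint (fake_aiter_bytes is invented for illustration):

    async def fake_aiter_bytes():
        # Stand-in for httpx's response.aiter_bytes(): an async generator of byte chunks.
        yield b'[{"text": "hi"}]'

    try:
        iter(fake_aiter_bytes())      # what __init__ would do without the sync_stream flag
    except TypeError as err:
        print(err)                    # 'async_generator' object is not iterable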
@@ -705,6 +706,25 @@ class VertexLLM(BaseLLM):

         ### ROUTING (ASYNC, STREAMING, SYNC)
         if acompletion:
+            ### ASYNC STREAMING
+            if stream is True:
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=json.dumps(data),  # type: ignore
+                    api_base=url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=stream,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                    client=client,  # type: ignore
+                )
             ### ASYNC COMPLETION
             return self.async_completion(
                 model=model,
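With this routing in place, an async call with stream=True is dispatched to async_streaming() instead of falling through to async_completion(). A hypothetical end-to-end usage sketch (model name and prompt are illustrative, and it assumes a configured Google AI Studio API key):

    import asyncio
    import litellm

    async def main():
        # acompletion + stream=True exercises the async streaming path above.
        response = await litellm.acompletion(
            model="gemini/gemini-1.5-flash",   # "gemini/" prefix = Google AI Studio
            messages=[{"role": "user", "content": "Say hi"}],
            stream=True,
        )
        async for chunk in response:
            print(chunk)

    asyncio.run(main())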
@@ -916,9 +936,13 @@ class VertexLLM(BaseLLM):


 class ModelResponseIterator:
-    def __init__(self, streaming_response):
+    def __init__(self, streaming_response, sync_stream: bool):
         self.streaming_response = streaming_response
-        self.response_iterator = iter(self.streaming_response)
+        if sync_stream:
+            self.response_iterator = iter(self.streaming_response)
+
+        self.events = ijson.sendable_list()
+        self.coro = ijson.items_coro(self.events, "item")

     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
@@ -970,10 +994,21 @@ class ModelResponseIterator:

     def __next__(self):
         try:
-            chunk = next(self.response_iterator)
-            chunk = chunk.decode()
-            json_chunk = json.loads(chunk)
-            return self.chunk_parser(chunk=json_chunk)
+            chunk = self.response_iterator.__next__()
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
         except StopIteration:
             raise StopIteration
         except ValueError as e:
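The rewritten __next__ no longer assumes each transport chunk is a complete JSON document: it feeds raw bytes into the ijson coroutine and only hands data to chunk_parser() once a whole array element has accumulated; until then it returns an empty placeholder GenericStreamingChunk so iteration keeps moving. A toy reproduction of that control flow outside the class (names and chunk boundaries are made up):

    import ijson

    def parse_stream(byte_chunks):
        """Yield one parsed element, or None as a placeholder, per incoming byte chunk."""
        events = ijson.sendable_list()
        coro = ijson.items_coro(events, "item")
        for chunk in byte_chunks:
            coro.send(chunk)
            if events:
                item = events[0]
                events.clear()
                yield item     # complete element: the real code calls chunk_parser() here
            else:
                yield None     # element still incomplete: placeholder chunk

    print(list(parse_stream([b'[{"n"', b': 1}]'])))   # [None, {'n': 1}]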
@@ -987,9 +1022,20 @@ class ModelResponseIterator:
     async def __anext__(self):
         try:
             chunk = await self.async_response_iterator.__anext__()
-            chunk = chunk.decode()
-            json_chunk = json.loads(chunk)
-            return self.chunk_parser(chunk=json_chunk)
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
         except StopAsyncIteration:
             raise StopAsyncIteration
         except ValueError as e:
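__anext__ mirrors the sync path but pulls from self.async_response_iterator, which is set up elsewhere in the class from the async byte stream passed in by make_call. A self-contained sketch of driving the same ijson pattern from an async byte source (fake_gemini_bytes is invented for illustration):

    import asyncio
    import ijson

    async def fake_gemini_bytes():
        # Stand-in for the async byte stream: one array element split across chunks.
        for piece in (b'[{"text": "He', b'llo"}', b']'):
            yield piece

    async def consume():
        events = ijson.sendable_list()
        coro = ijson.items_coro(events, "item")
        async for chunk in fake_gemini_bytes():
            coro.send(chunk)
            if events:
                print("parsed:", events[0])   # the real code calls chunk_parser() here
                events.clear()
            else:
                print("waiting for more bytes")

    asyncio.run(consume())
    # waiting for more bytes
    # parsed: {'text': 'Hello'}
    # waiting for more bytes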