Krrish Dholakia 2024-06-28 10:38:19 -07:00
parent 2bd8205ef5
commit a7122f91a1
3 changed files with 36 additions and 90 deletions

View file

@@ -1,6 +1,11 @@
import asyncio
import os
import traceback
from typing import Any, Mapping, Optional, Union
import httpx
import litellm
import httpx, asyncio, traceback, os
from typing import Optional, Union, Mapping, Any
# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
@@ -208,6 +213,7 @@ class HTTPHandler:
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
)
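For readers unfamiliar with the pattern above: httpx streams a response by building the request explicitly and sending it with stream=True, which defers reading the body. A minimal standalone sketch of that pattern (illustrative only; the placeholder URL and client are not part of this commit, and the timeout simply reuses the default shown at the top of the file):

import httpx

# Hypothetical standalone client; the handler above wraps this pattern with its own defaults.
client = httpx.Client(timeout=httpx.Timeout(timeout=5.0, connect=5.0))
req = client.build_request("POST", "https://example.com/stream", json={"ping": True})
resp = client.send(req, stream=True)  # stream=True: the body is read lazily
for chunk in resp.iter_bytes():
    print(len(chunk))  # consume raw bytes as they arrive
resp.close()  # always close a streamed response when done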

View file

@@ -491,7 +491,7 @@ def make_sync_call(
raise VertexAIError(status_code=response.status_code, message=response.read())
completion_stream = ModelResponseIterator(
streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
streaming_response=response.iter_bytes(), sync_stream=True
)
# LOGGING
@@ -811,11 +811,12 @@ class VertexLLM(BaseLLM):
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = (
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
_gemini_model_name, endpoint, gemini_api_key
)
else:
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
else:
auth_header, vertex_project = self._ensure_access_token(
@@ -827,6 +828,8 @@ class VertexLLM(BaseLLM):
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
else:
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if (
@@ -840,6 +843,9 @@ class VertexLLM(BaseLLM):
else:
url = "{}:{}".format(api_base, endpoint)
if stream is True:
url = url + "?alt=sse"
return auth_header, url
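The change above settles on one convention: when stream is True, "alt=sse" is appended to the query string (with "&" when a "?key=" parameter already exists, "?" otherwise) so the API returns server-sent events instead of a single JSON array. A hypothetical helper condensing the same logic for the Gemini (API-key) branch, shown only for illustration:

# Hypothetical helper, not part of the diff; it mirrors the URL logic above.
def build_gemini_url(model: str, api_key: str, stream: bool) -> str:
    endpoint = "streamGenerateContent" if stream else "generateContent"
    url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
        model, endpoint, api_key
    )
    if stream:
        url += "&alt=sse"  # request server-sent events for streaming responses
    return url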
async def async_streaming(
@@ -1268,11 +1274,6 @@ class VertexLLM(BaseLLM):
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
if sync_stream:
self.response_iterator = iter(self.streaming_response)
self.events = ijson.sendable_list()
self.coro = ijson.items_coro(self.events, "item")
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
@@ -1322,28 +1323,18 @@ class ModelResponseIterator:
# Sync iterator
def __iter__(self):
self.response_iterator = self.streaming_response
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
self.coro.send(chunk)
if self.events:
event = self.events.pop(0)
json_chunk = event
chunk = chunk.decode()
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
if self.events: # flush the events
event = self.events.pop(0) # Remove the first event
return self.chunk_parser(chunk=event)
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e}")
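The rewritten sync path assumes each chunk delivered under alt=sse is a complete "data: {...}" line: it decodes the bytes, drops the SSE field name, and parses the remaining JSON. A self-contained sketch of just that parsing step (illustrative; the payload below is a made-up example in the Gemini response shape):

import json

def parse_sse_chunk(raw: bytes) -> dict:
    text = raw.decode()
    text = text.replace("data:", "").strip()  # strip the SSE prefix, keep the JSON payload
    return json.loads(text)

example = b'data: {"candidates": [{"content": {"parts": [{"text": "hi"}]}}]}\n\n'
print(parse_sse_chunk(example)["candidates"][0]["content"]["parts"][0]["text"])  # -> hi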
@@ -1356,23 +1347,12 @@ class ModelResponseIterator:
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
self.coro.send(chunk)
if self.events:
event = self.events.pop(0)
json_chunk = event
chunk = chunk.decode()
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
if self.events: # flush the events
event = self.events.pop(0) # Remove the first event
return self.chunk_parser(chunk=event)
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e}")
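The async branch mirrors the sync one over httpx's async byte iterator. Consuming the resulting iterator looks roughly like the sketch below (illustrative only; it assumes each parsed chunk can be read like a dict with a "text" key, consistent with how GenericStreamingChunk is constructed above):

import asyncio

async def collect_text(chunk_iterator) -> str:
    pieces = []
    async for chunk in chunk_iterator:  # e.g. a ModelResponseIterator built with sync_stream=False
        pieces.append(chunk["text"])    # "text" access is an assumption for this sketch
    return "".join(pieces)

# asyncio.run(collect_text(iterator))  # "iterator" is whatever async stream you hold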

View file

@@ -742,7 +742,10 @@ def test_completion_palm_stream():
# test_completion_palm_stream()
@pytest.mark.parametrize("sync_mode", [False]) # True,
@pytest.mark.parametrize(
"sync_mode",
[True, False],
) # ,
@pytest.mark.asyncio
async def test_completion_gemini_stream(sync_mode):
try:
@@ -807,49 +810,6 @@ async def test_completion_gemini_stream(sync_mode):
pytest.fail(f"Error occurred: {e}")
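The parametrize change above folds the separate async-only test into one test that exercises both paths. A generic, hedged sketch of that sync/async parametrization pattern (hypothetical stand-in functions; assumes pytest-asyncio, which this test file already relies on):

import pytest

def do_work() -> str:  # hypothetical sync entry point
    return "ok"

async def ado_work() -> str:  # hypothetical async entry point
    return "ok"

@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_both_paths(sync_mode):
    result = do_work() if sync_mode else await ado_work()
    assert result == "ok"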
@pytest.mark.asyncio
async def test_acompletion_gemini_stream():
try:
litellm.set_verbose = True
print("Streaming gemini response")
messages = [
# {"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "What do you know?",
},
]
print("testing gemini streaming")
response = await acompletion(
model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
)
print(f"type of response at the top: {response}")
complete_response = ""
idx = 0
# Add any assertions here to check, the response
async for chunk in response:
print(f"chunk in acompletion gemini: {chunk}")
print(chunk.choices[0].delta)
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
break
print(f"chunk: {chunk}")
complete_response += chunk
idx += 1
print(f"completion_response: {complete_response}")
if complete_response.strip() == "":
raise Exception("Empty response received")
except litellm.APIError as e:
pass
except litellm.RateLimitError as e:
pass
except Exception as e:
if "429 Resource has been exhausted" in str(e):
pass
else:
pytest.fail(f"Error occurred: {e}")
# asyncio.run(test_acompletion_gemini_stream())
@@ -1071,7 +1031,7 @@ def test_completion_claude_stream_bad_key():
# test_completion_replicate_stream()
@pytest.mark.parametrize("provider", ["vertex_ai"]) # "vertex_ai_beta"
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # ""
def test_vertex_ai_stream(provider):
from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials