Krrish Dholakia 2024-06-28 10:38:19 -07:00
parent 2bd8205ef5
commit a7122f91a1
3 changed files with 36 additions and 90 deletions

View file

@@ -1,6 +1,11 @@
+import asyncio
+import os
+import traceback
+from typing import Any, Mapping, Optional, Union
+
+import httpx
+
 import litellm
-import httpx, asyncio, traceback, os
-from typing import Optional, Union, Mapping, Any
 
 # https://www.python-httpx.org/advanced/timeouts
 _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
@@ -208,6 +213,7 @@ class HTTPHandler:
         headers: Optional[dict] = None,
         stream: bool = False,
     ):
         req = self.client.build_request(
             "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
         )
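For context, HTTPHandler.post takes a stream flag alongside the usual request arguments and builds the request through the shared httpx client. A rough usage sketch follows; the import path and the assumption that a streamed post hands back a response whose body can be read incrementally are inferred from how make_sync_call consumes the response in the Vertex changes below, not shown in this hunk:

# Hypothetical usage sketch; import path and streaming return semantics are assumptions.
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
response = client.post(
    "https://example.com/v1/stream",  # placeholder URL
    json={"prompt": "hello"},
    stream=True,
)
for raw_chunk in response.iter_bytes():
    print(raw_chunk)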

View file

@@ -491,7 +491,7 @@ def make_sync_call(
         raise VertexAIError(status_code=response.status_code, message=response.read())
 
     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
+        streaming_response=response.iter_bytes(), sync_stream=True
     )
 
     # LOGGING
@@ -811,12 +811,13 @@ class VertexLLM(BaseLLM):
             endpoint = "generateContent"
             if stream is True:
                 endpoint = "streamGenerateContent"
-            url = (
-                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
                     _gemini_model_name, endpoint, gemini_api_key
                 )
-            )
+            else:
+                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                    _gemini_model_name, endpoint, gemini_api_key
+                )
         else:
             auth_header, vertex_project = self._ensure_access_token(
                 credentials=vertex_credentials, project_id=vertex_project
@@ -827,7 +828,9 @@ class VertexLLM(BaseLLM):
             endpoint = "generateContent"
             if stream is True:
                 endpoint = "streamGenerateContent"
-            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+                url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
+            else:
+                url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
 
         if (
             api_base is not None
@@ -840,6 +843,9 @@ class VertexLLM(BaseLLM):
             else:
                 url = "{}:{}".format(api_base, endpoint)
 
+        if stream is True:
+            url = url + "?alt=sse"
+
         return auth_header, url
 
     async def async_streaming(
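Taken together, these URL changes make every streaming request ask Google for Server-Sent Events. A minimal sketch of the resulting endpoint selection for the Gemini (AI Studio) path, written as a standalone helper for illustration; the helper name and example model string are hypothetical, while the URL shape, endpoints, and alt=sse parameter come from the diff:

def build_gemini_url(model_name: str, api_key: str, stream: bool) -> str:
    # Streaming uses the streamGenerateContent endpoint and requests SSE framing
    # via alt=sse; non-streaming keeps the plain generateContent endpoint.
    endpoint = "streamGenerateContent" if stream else "generateContent"
    url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
        model_name, endpoint, api_key
    )
    if stream:
        url = url + "&alt=sse"
    return url

# e.g. build_gemini_url("models/gemini-pro", "<key>", stream=True)  # assumed model-name format
# -> ".../v1beta/models/gemini-pro:streamGenerateContent?key=<key>&alt=sse"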
@@ -1268,11 +1274,6 @@
 class ModelResponseIterator:
     def __init__(self, streaming_response, sync_stream: bool):
         self.streaming_response = streaming_response
-        if sync_stream:
-            self.response_iterator = iter(self.streaming_response)
-        self.events = ijson.sendable_list()
-        self.coro = ijson.items_coro(self.events, "item")
 
     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
@@ -1322,28 +1323,18 @@
     # Sync iterator
     def __iter__(self):
+        self.response_iterator = self.streaming_response
         return self
 
     def __next__(self):
         try:
             chunk = self.response_iterator.__next__()
-            self.coro.send(chunk)
-            if self.events:
-                event = self.events.pop(0)
-                json_chunk = event
-                return self.chunk_parser(chunk=json_chunk)
-            return GenericStreamingChunk(
-                text="",
-                is_finished=False,
-                finish_reason="",
-                usage=None,
-                index=0,
-                tool_use=None,
-            )
+            chunk = chunk.decode()
+            chunk = chunk.replace("data:", "")
+            chunk = chunk.strip()
+            json_chunk = json.loads(chunk)
+            return self.chunk_parser(chunk=json_chunk)
         except StopIteration:
-            if self.events:  # flush the events
-                event = self.events.pop(0)  # Remove the first event
-                return self.chunk_parser(chunk=event)
             raise StopIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e}")
@@ -1356,23 +1347,12 @@ class ModelResponseIterator:
     async def __anext__(self):
         try:
             chunk = await self.async_response_iterator.__anext__()
-            self.coro.send(chunk)
-            if self.events:
-                event = self.events.pop(0)
-                json_chunk = event
-                return self.chunk_parser(chunk=json_chunk)
-            return GenericStreamingChunk(
-                text="",
-                is_finished=False,
-                finish_reason="",
-                usage=None,
-                index=0,
-                tool_use=None,
-            )
+            chunk = chunk.decode()
+            chunk = chunk.replace("data:", "")
+            chunk = chunk.strip()
+            json_chunk = json.loads(chunk)
+            return self.chunk_parser(chunk=json_chunk)
         except StopAsyncIteration:
-            if self.events:  # flush the events
-                event = self.events.pop(0)  # Remove the first event
-                return self.chunk_parser(chunk=event)
             raise StopAsyncIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e}")
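With alt=sse on the request, each chunk the iterator receives is expected to be a single data: {...} event, so the ijson incremental parser is dropped in favor of a straight decode-strip-parse step shared by the sync and async paths. A minimal sketch of that per-chunk step as a standalone function (hypothetical name; it assumes one complete SSE event per chunk, exactly as the new __next__/__anext__ bodies do):

import json

def parse_sse_chunk(raw_chunk: bytes) -> dict:
    # Mirror of the new per-chunk handling: drop the SSE "data:" framing,
    # trim whitespace, and parse the remaining JSON payload.
    text = raw_chunk.decode()
    text = text.replace("data:", "")
    text = text.strip()
    return json.loads(text)

# Example: parse_sse_chunk(b'data: {"candidates": []}') returns {"candidates": []}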

View file

@@ -742,7 +742,10 @@ def test_completion_palm_stream():
 # test_completion_palm_stream()
 
 
-@pytest.mark.parametrize("sync_mode", [False])  # True,
+@pytest.mark.parametrize(
+    "sync_mode",
+    [True, False],
+)  # ,
 @pytest.mark.asyncio
 async def test_completion_gemini_stream(sync_mode):
     try:
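The parametrization now covers both the sync and async clients. A rough sketch of how such a test typically dispatches on sync_mode in this suite follows; the test body is a simplified, hypothetical reduction, while completion/acompletion and the gemini/gemini-pro model string are real litellm names used by the surrounding tests:

import pytest
from litellm import acompletion, completion

@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_gemini_stream_smoke(sync_mode):
    # Hypothetical reduced test: stream the same prompt through the sync or
    # async client depending on the flag and assert that chunks arrive.
    messages = [{"role": "user", "content": "What do you know?"}]
    chunks = []
    if sync_mode:
        for chunk in completion(
            model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
        ):
            chunks.append(chunk)
    else:
        response = await acompletion(
            model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
        )
        async for chunk in response:
            chunks.append(chunk)
    assert len(chunks) > 0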
@@ -807,49 +810,6 @@ async def test_completion_gemini_stream(sync_mode):
         pytest.fail(f"Error occurred: {e}")
 
 
-@pytest.mark.asyncio
-async def test_acompletion_gemini_stream():
-    try:
-        litellm.set_verbose = True
-        print("Streaming gemini response")
-        messages = [
-            # {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "user",
-                "content": "What do you know?",
-            },
-        ]
-        print("testing gemini streaming")
-        response = await acompletion(
-            model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
-        )
-        print(f"type of response at the top: {response}")
-        complete_response = ""
-        idx = 0
-        # Add any assertions here to check, the response
-        async for chunk in response:
-            print(f"chunk in acompletion gemini: {chunk}")
-            print(chunk.choices[0].delta)
-            chunk, finished = streaming_format_tests(idx, chunk)
-            if finished:
-                break
-            print(f"chunk: {chunk}")
-            complete_response += chunk
-            idx += 1
-        print(f"completion_response: {complete_response}")
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
-    except litellm.APIError as e:
-        pass
-    except litellm.RateLimitError as e:
-        pass
-    except Exception as e:
-        if "429 Resource has been exhausted" in str(e):
-            pass
-        else:
-            pytest.fail(f"Error occurred: {e}")
 
 
 # asyncio.run(test_acompletion_gemini_stream())
@@ -1071,7 +1031,7 @@ def test_completion_claude_stream_bad_key():
 # test_completion_replicate_stream()
 
 
-@pytest.mark.parametrize("provider", ["vertex_ai"])  # "vertex_ai_beta"
+@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # ""
 def test_vertex_ai_stream(provider):
     from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials