diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index a3c5865fa..d24acaecc 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -1,6 +1,11 @@
+import asyncio
+import os
+import traceback
+from typing import Any, Mapping, Optional, Union
+
+import httpx
+
 import litellm
-import httpx, asyncio, traceback, os
-from typing import Optional, Union, Mapping, Any
 
 # https://www.python-httpx.org/advanced/timeouts
 _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
@@ -208,6 +213,7 @@ class HTTPHandler:
         headers: Optional[dict] = None,
         stream: bool = False,
     ):
+
         req = self.client.build_request(
             "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
         )
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 18b1088ba..523120457 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -491,7 +491,7 @@ def make_sync_call(
         raise VertexAIError(status_code=response.status_code, message=response.read())
 
     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
+        streaming_response=response.iter_bytes(), sync_stream=True
     )
 
     # LOGGING
@@ -811,12 +811,13 @@ class VertexLLM(BaseLLM):
             endpoint = "generateContent"
             if stream is True:
                 endpoint = "streamGenerateContent"
-
-            url = (
-                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
+                    _gemini_model_name, endpoint, gemini_api_key
+                )
+            else:
+                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
                     _gemini_model_name, endpoint, gemini_api_key
                 )
-            )
         else:
             auth_header, vertex_project = self._ensure_access_token(
                 credentials=vertex_credentials, project_id=vertex_project
@@ -827,7 +828,9 @@ class VertexLLM(BaseLLM):
             endpoint = "generateContent"
             if stream is True:
                 endpoint = "streamGenerateContent"
-            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+                url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
+            else:
+                url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
 
         if (
             api_base is not None
@@ -840,6 +843,9 @@ class VertexLLM(BaseLLM):
             else:
                 url = "{}:{}".format(api_base, endpoint)
 
+        if stream is True:
+            url = url + "?alt=sse"
+
         return auth_header, url
 
     async def async_streaming(
@@ -1268,11 +1274,6 @@ class VertexLLM(BaseLLM):
 class ModelResponseIterator:
     def __init__(self, streaming_response, sync_stream: bool):
         self.streaming_response = streaming_response
-        if sync_stream:
-            self.response_iterator = iter(self.streaming_response)
-
-        self.events = ijson.sendable_list()
-        self.coro = ijson.items_coro(self.events, "item")
 
     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
@@ -1322,28 +1323,18 @@ class ModelResponseIterator:
 
     # Sync iterator
     def __iter__(self):
+        self.response_iterator = self.streaming_response
         return self
 
     def __next__(self):
         try:
             chunk = self.response_iterator.__next__()
-            self.coro.send(chunk)
-            if self.events:
-                event = self.events.pop(0)
-                json_chunk = event
-                return self.chunk_parser(chunk=json_chunk)
-            return GenericStreamingChunk(
-                text="",
-                is_finished=False,
-                finish_reason="",
-                usage=None,
-                index=0,
-                tool_use=None,
-            )
+            chunk = chunk.decode()
+            chunk = chunk.replace("data:", "")
+            chunk = chunk.strip()
+            json_chunk = json.loads(chunk)
+            return self.chunk_parser(chunk=json_chunk)
         except StopIteration:
-            if self.events:  # flush the events
-                event = self.events.pop(0)  # Remove the first event
-                return self.chunk_parser(chunk=event)
             raise StopIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e}")
@@ -1356,23 +1347,12 @@ class ModelResponseIterator:
     async def __anext__(self):
         try:
             chunk = await self.async_response_iterator.__anext__()
-            self.coro.send(chunk)
-            if self.events:
-                event = self.events.pop(0)
-                json_chunk = event
-                return self.chunk_parser(chunk=json_chunk)
-            return GenericStreamingChunk(
-                text="",
-                is_finished=False,
-                finish_reason="",
-                usage=None,
-                index=0,
-                tool_use=None,
-            )
+            chunk = chunk.decode()
+            chunk = chunk.replace("data:", "")
+            chunk = chunk.strip()
+            json_chunk = json.loads(chunk)
+            return self.chunk_parser(chunk=json_chunk)
         except StopAsyncIteration:
-            if self.events:  # flush the events
-                event = self.events.pop(0)  # Remove the first event
-                return self.chunk_parser(chunk=event)
             raise StopAsyncIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e}")
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 3042e91b3..5cd0e35a9 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -742,7 +742,10 @@ def test_completion_palm_stream():
 # test_completion_palm_stream()
 
 
-@pytest.mark.parametrize("sync_mode", [False])  # True,
+@pytest.mark.parametrize(
+    "sync_mode",
+    [True, False],
+)  # ,
 @pytest.mark.asyncio
 async def test_completion_gemini_stream(sync_mode):
     try:
@@ -807,49 +810,6 @@ async def test_completion_gemini_stream(sync_mode):
             pytest.fail(f"Error occurred: {e}")
 
 
-@pytest.mark.asyncio
-async def test_acompletion_gemini_stream():
-    try:
-        litellm.set_verbose = True
-        print("Streaming gemini response")
-        messages = [
-            # {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "user",
-                "content": "What do you know?",
-            },
-        ]
-        print("testing gemini streaming")
-        response = await acompletion(
-            model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
-        )
-        print(f"type of response at the top: {response}")
-        complete_response = ""
-        idx = 0
-        # Add any assertions here to check, the response
-        async for chunk in response:
-            print(f"chunk in acompletion gemini: {chunk}")
-            print(chunk.choices[0].delta)
-            chunk, finished = streaming_format_tests(idx, chunk)
-            if finished:
-                break
-            print(f"chunk: {chunk}")
-            complete_response += chunk
-            idx += 1
-        print(f"completion_response: {complete_response}")
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
-    except litellm.APIError as e:
-        pass
-    except litellm.RateLimitError as e:
-        pass
-    except Exception as e:
-        if "429 Resource has been exhausted" in str(e):
-            pass
-        else:
-            pytest.fail(f"Error occurred: {e}")
-
-
 # asyncio.run(test_acompletion_gemini_stream())
 
 
@@ -1071,7 +1031,7 @@ def test_completion_claude_stream_bad_key():
 # test_completion_replicate_stream()
 
 
-@pytest.mark.parametrize("provider", ["vertex_ai"])  # "vertex_ai_beta"
+@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # ""
 def test_vertex_ai_stream(provider):
     from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials