forked from phoenix/litellm-mirror
fix(support-'alt=sse'-param): Fixes https://github.com/BerriAI/litellm/issues/4459
parent 2bd8205ef5
commit a7122f91a1
3 changed files with 36 additions and 90 deletions
Changed file 1 of 3: the shared httpx request handler (class HTTPHandler)
@@ -1,6 +1,11 @@
+import asyncio
+import os
+import traceback
+from typing import Any, Mapping, Optional, Union
+
+import httpx
+
 import litellm
-import httpx, asyncio, traceback, os
-from typing import Optional, Union, Mapping, Any

 # https://www.python-httpx.org/advanced/timeouts
 _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
@@ -208,6 +213,7 @@ class HTTPHandler:
         headers: Optional[dict] = None,
+        stream: bool = False,
     ):
         req = self.client.build_request(
             "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
         )
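The hunk above only adds the new stream flag to the handler's signature; the rest of the method body sits outside the diff context. A minimal standalone sketch of the build_request/send pattern such a flag enables, using httpx's public API against a mock transport so it runs on its own (everything here is illustrative, not litellm's actual handler code):

    import httpx

    def mock_sse_endpoint(request: httpx.Request) -> httpx.Response:
        # Pretend to be an SSE endpoint: two events, each one complete JSON payload.
        body = b'data: {"text": "Hel"}\r\n\r\ndata: {"text": "lo"}\r\n\r\n'
        return httpx.Response(200, content=body)

    client = httpx.Client(transport=httpx.MockTransport(mock_sse_endpoint))

    req = client.build_request("POST", "https://mock/streamGenerateContent?alt=sse", json={})
    resp = client.send(req, stream=True)  # stream=True defers reading the body
    try:
        for line in resp.iter_lines():
            if line.strip():
                print(line)  # each non-empty line is one "data: {...}" event
    finally:
        resp.close()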
Changed file 2 of 3: the Vertex AI / Gemini integration (class VertexLLM and its ModelResponseIterator)
@@ -491,7 +491,7 @@ def make_sync_call(
         raise VertexAIError(status_code=response.status_code, message=response.read())

     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
+        streaming_response=response.iter_bytes(), sync_stream=True
     )

     # LOGGING
@@ -811,11 +811,12 @@ class VertexLLM(BaseLLM):
             endpoint = "generateContent"
             if stream is True:
                 endpoint = "streamGenerateContent"
-
-            url = (
-                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
                     _gemini_model_name, endpoint, gemini_api_key
                 )
-            )
+            else:
+                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                    _gemini_model_name, endpoint, gemini_api_key
+                )
         else:
             auth_header, vertex_project = self._ensure_access_token(
@@ -827,6 +828,8 @@ class VertexLLM(BaseLLM):
             endpoint = "generateContent"
             if stream is True:
                 endpoint = "streamGenerateContent"
+                url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
+            else:
                 url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"

             if (
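The two URL branches differ only in how the query string begins: the Gemini AI Studio endpoint already carries the API key as a query parameter, so alt=sse is appended with &, while the Vertex AI endpoint has no query string yet and starts one with ?. Illustration with placeholder model, project, and key values (not taken from the diff):

    # Gemini (API-key) endpoint: the key is already a query parameter, so use "&".
    gemini_stream_url = (
        "https://generativelanguage.googleapis.com/v1beta/"
        "models/gemini-pro:streamGenerateContent?key=YOUR_API_KEY&alt=sse"
    )

    # Vertex AI endpoint: no query string yet, so "?" starts one.
    vertex_stream_url = (
        "https://us-central1-aiplatform.googleapis.com/v1/projects/your-project/"
        "locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent?alt=sse"
    )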
@@ -840,6 +843,9 @@ class VertexLLM(BaseLLM):
         else:
             url = "{}:{}".format(api_base, endpoint)

+        if stream is True:
+            url = url + "?alt=sse"
+
         return auth_header, url

     async def async_streaming(
@@ -1268,11 +1274,6 @@ class VertexLLM(BaseLLM):
 class ModelResponseIterator:
     def __init__(self, streaming_response, sync_stream: bool):
         self.streaming_response = streaming_response
-        if sync_stream:
-            self.response_iterator = iter(self.streaming_response)
-
-        self.events = ijson.sendable_list()
-        self.coro = ijson.items_coro(self.events, "item")

     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
@@ -1322,28 +1323,18 @@ class ModelResponseIterator:

     # Sync iterator
     def __iter__(self):
         self.response_iterator = self.streaming_response
         return self

     def __next__(self):
         try:
             chunk = self.response_iterator.__next__()
-            self.coro.send(chunk)
-            if self.events:
-                event = self.events.pop(0)
-                json_chunk = event
+            chunk = chunk.decode()
+            chunk = chunk.replace("data:", "")
+            chunk = chunk.strip()
+            json_chunk = json.loads(chunk)
             return self.chunk_parser(chunk=json_chunk)
-            return GenericStreamingChunk(
-                text="",
-                is_finished=False,
-                finish_reason="",
-                usage=None,
-                index=0,
-                tool_use=None,
-            )
         except StopIteration:
-            if self.events:  # flush the events
-                event = self.events.pop(0)  # Remove the first event
-                return self.chunk_parser(chunk=event)
             raise StopIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e}")
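With alt=sse each update arrives as its own "data: {...}" server-sent event rather than as elements of one long streamed JSON array, which is why the ijson coroutine machinery can be dropped: every chunk is decoded, stripped of its data: prefix, and parsed with a plain json.loads. A standalone sketch of that per-chunk parse follows; the sample payload is shaped like a Gemini candidates response but is illustrative, not captured from the diff:

    import json

    # One SSE event as it might come off response.iter_bytes():
    raw_chunk = b'data: {"candidates": [{"content": {"parts": [{"text": "Hello"}]}}]}\r\n\r\n'

    text = raw_chunk.decode()
    text = text.replace("data:", "")
    text = text.strip()
    parsed = json.loads(text)
    print(parsed["candidates"][0]["content"]["parts"][0]["text"])  # -> Hello

    # A fixed-size slice of a larger event (what iter_bytes(chunk_size=2056) could
    # yield) is not valid JSON on its own, which is presumably why the explicit
    # chunk_size was removed from make_sync_call above.
    fragment = raw_chunk[:30].decode().replace("data:", "").strip()
    try:
        json.loads(fragment)
    except json.JSONDecodeError as err:
        print(f"partial event cannot be parsed: {err}")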
@@ -1356,23 +1347,12 @@ class ModelResponseIterator:
     async def __anext__(self):
         try:
             chunk = await self.async_response_iterator.__anext__()
-            self.coro.send(chunk)
-            if self.events:
-                event = self.events.pop(0)
-                json_chunk = event
+            chunk = chunk.decode()
+            chunk = chunk.replace("data:", "")
+            chunk = chunk.strip()
+            json_chunk = json.loads(chunk)
             return self.chunk_parser(chunk=json_chunk)
-            return GenericStreamingChunk(
-                text="",
-                is_finished=False,
-                finish_reason="",
-                usage=None,
-                index=0,
-                tool_use=None,
-            )
         except StopAsyncIteration:
-            if self.events:  # flush the events
-                event = self.events.pop(0)  # Remove the first event
-                return self.chunk_parser(chunk=event)
             raise StopAsyncIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e}")
Changed file 3 of 3: the streaming test suite
@@ -742,7 +742,10 @@ def test_completion_palm_stream():
 # test_completion_palm_stream()


-@pytest.mark.parametrize("sync_mode", [False])  # True,
+@pytest.mark.parametrize(
+    "sync_mode",
+    [True, False],
+)  # ,
 @pytest.mark.asyncio
 async def test_completion_gemini_stream(sync_mode):
     try:
@@ -807,49 +810,6 @@ async def test_completion_gemini_stream(sync_mode):
             pytest.fail(f"Error occurred: {e}")


-@pytest.mark.asyncio
-async def test_acompletion_gemini_stream():
-    try:
-        litellm.set_verbose = True
-        print("Streaming gemini response")
-        messages = [
-            # {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "user",
-                "content": "What do you know?",
-            },
-        ]
-        print("testing gemini streaming")
-        response = await acompletion(
-            model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
-        )
-        print(f"type of response at the top: {response}")
-        complete_response = ""
-        idx = 0
-        # Add any assertions here to check, the response
-        async for chunk in response:
-            print(f"chunk in acompletion gemini: {chunk}")
-            print(chunk.choices[0].delta)
-            chunk, finished = streaming_format_tests(idx, chunk)
-            if finished:
-                break
-            print(f"chunk: {chunk}")
-            complete_response += chunk
-            idx += 1
-        print(f"completion_response: {complete_response}")
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
-    except litellm.APIError as e:
-        pass
-    except litellm.RateLimitError as e:
-        pass
-    except Exception as e:
-        if "429 Resource has been exhausted" in str(e):
-            pass
-        else:
-            pytest.fail(f"Error occurred: {e}")
-
-
 # asyncio.run(test_acompletion_gemini_stream())
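The standalone async test removed here is superseded by the parametrized test_completion_gemini_stream(sync_mode) earlier in the file, which now exercises both the synchronous and asynchronous streaming paths. A self-contained sketch of that parametrize-over-sync/async pattern, with toy generators standing in for the litellm calls (assumes pytest and pytest-asyncio are installed):

    import pytest

    def sync_stream():
        yield from ["Hel", "lo"]

    async def async_stream():
        for piece in ["Hel", "lo"]:
            yield piece

    @pytest.mark.parametrize("sync_mode", [True, False])
    @pytest.mark.asyncio
    async def test_stream_both_modes(sync_mode):
        collected = []
        if sync_mode:
            for piece in sync_stream():
                collected.append(piece)
        else:
            async for piece in async_stream():
                collected.append(piece)
        assert "".join(collected) == "Hello"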
@@ -1071,7 +1031,7 @@ def test_completion_claude_stream_bad_key():
 # test_completion_replicate_stream()


-@pytest.mark.parametrize("provider", ["vertex_ai"])  # "vertex_ai_beta"
+@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # ""
 def test_vertex_ai_stream(provider):
     from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials