fix(vertex_httpx.py): support streaming via httpx client

Krrish Dholakia 2024-06-12 19:55:14 -07:00
parent 3b913443fe
commit 3955b058ed
7 changed files with 283 additions and 26 deletions

View file

@@ -605,6 +605,7 @@ provider_list: List = [
"together_ai",
"openrouter",
"vertex_ai",
"vertex_ai_beta",
"palm",
"gemini",
"ai21",

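Registering "vertex_ai_beta" in provider_list is what lets callers select the new httpx-based Vertex AI code path with a provider prefix on the model name. A minimal sketch of that routing, assuming Vertex AI credentials are already configured in the environment and that "gemini-1.5-pro" is only a placeholder for whichever Gemini model the project can access:

import litellm

# The "vertex_ai_beta/" prefix routes the call through the new httpx client;
# the model name below is a placeholder, not taken from this commit.
response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro",
    messages=[{"role": "user", "content": "Say hi in one line."}],
)
print(response.choices[0].message.content)
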
View file

@@ -1,3 +1,7 @@
# What is this?
## httpx client for vertex ai calls
## Initial implementation - covers gemini + image gen calls
from functools import partial
import os, types
import json
from enum import Enum
@@ -17,6 +21,86 @@ from litellm.types.llms.vertex_ai import (
GenerateContentResponseBody,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.utils import GenericStreamingChunk
from litellm.types.llms.openai import (
ChatCompletionUsageBlock,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
)
class VertexGeminiConfig:
def __init__(self) -> None:
pass
def supports_system_message(self) -> bool:
"""
Not all gemini models support system instructions
"""
return True
async def make_call(
client: Optional[AsyncHTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = AsyncHTTPHandler() # Create a new client if none provided
response = await client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise VertexAIError(status_code=response.status_code, message=response.text)
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_bytes(chunk_size=2056)
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream
def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise VertexAIError(status_code=response.status_code, message=response.read())
completion_stream = ModelResponseIterator(
streaming_response=response.iter_bytes(chunk_size=2056)
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream
class VertexAIError(Exception):
@@ -46,7 +130,6 @@ class VertexLLM(BaseLLM):
model: str,
response: httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
@@ -77,7 +160,7 @@ class VertexLLM(BaseLLM):
status_code=422,
)
model_response.choices = []
model_response.choices = [] # type: ignore
## GET MODEL ##
model_response.model = model
@@ -190,6 +273,16 @@ class VertexLLM(BaseLLM):
return self._credentials.token, self.project_id
async def async_streaming(
self,
):
pass
async def async_completion(
self,
):
pass
def completion(
self,
model: str,
@@ -214,7 +307,7 @@ class VertexLLM(BaseLLM):
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
stream = optional_params.pop("stream", None)
stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore
### SET RUNTIME ENDPOINT ###
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
@@ -251,6 +344,26 @@ class VertexLLM(BaseLLM):
},
)
## SYNC STREAMING CALL ##
if stream is not None and stream is True:
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=None,
api_base=url,
headers=headers, # type: ignore
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="vertex_ai_beta",
logging_obj=logging_obj,
)
return streaming_response
## COMPLETION CALL ##
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
@@ -274,7 +387,6 @@ class VertexLLM(BaseLLM):
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
@@ -421,3 +533,84 @@ class VertexLLM(BaseLLM):
model_response.data = _response_data
return model_response
class ModelResponseIterator:
def __init__(self, streaming_response):
self.streaming_response = streaming_response
self.response_iterator = iter(self.streaming_response)
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
processed_chunk = GenerateContentResponseBody(**chunk) # type: ignore
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
gemini_chunk = processed_chunk["candidates"][0]
if (
"content" in gemini_chunk
and "text" in gemini_chunk["content"]["parts"][0]
):
text = gemini_chunk["content"]["parts"][0]["text"]
if "finishReason" in gemini_chunk:
finish_reason = map_finish_reason(
finish_reason=gemini_chunk["finishReason"]
)
is_finished = True
if "usageMetadata" in processed_chunk:
usage = ChatCompletionUsageBlock(
prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"],
completion_tokens=processed_chunk["usageMetadata"][
"candidatesTokenCount"
],
total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"],
)
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=0,
)
return returned_chunk
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
return self
def __next__(self):
try:
chunk = next(self.response_iterator)
chunk = chunk.decode()
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
chunk = chunk.decode()
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e}")
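
ModelResponseIterator.chunk_parser maps a decoded Gemini GenerateContentResponse body onto the provider-agnostic GenericStreamingChunk consumed by CustomStreamWrapper. A small illustration of that mapping, assuming the new module is importable as litellm.llms.vertex_httpx and using a hand-written response body in place of real streamed bytes:

from litellm.llms.vertex_httpx import ModelResponseIterator

# Hand-built chunk mirroring the fields the parser reads (candidates, finishReason,
# usageMetadata); the values are illustrative, not captured from a real response.
sample_chunk = {
    "candidates": [
        {
            "content": {"role": "model", "parts": [{"text": "Hello!"}]},
            "finishReason": "STOP",
            "index": 0,
        }
    ],
    "usageMetadata": {
        "promptTokenCount": 5,
        "candidatesTokenCount": 2,
        "totalTokenCount": 7,
    },
}

iterator = ModelResponseIterator(streaming_response=[])  # no live stream needed to parse
parsed = iterator.chunk_parser(chunk=sample_chunk)
print(parsed["text"], parsed["is_finished"], parsed["finish_reason"], parsed["usage"])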

View file

@@ -1875,6 +1875,42 @@ def completion(
)
return response
response = model_response
elif custom_llm_provider == "vertex_ai_beta":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
new_params = deepcopy(optional_params)
response = vertex_chat_completion.completion( # type: ignore
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
@@ -1911,26 +1947,6 @@ def completion(
logging_obj=logging,
acompletion=acompletion,
)
elif (
model in litellm.vertex_language_models
or model in litellm.vertex_vision_models
):
model_response = vertex_chat_completion.completion( # type: ignore
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
)
else:
model_response = vertex_ai.completion(
model=model,
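
The new vertex_ai_beta branch resolves project, location, and credentials the same way the existing vertex_ai branch does: per-call kwargs first, then (for project and location) the litellm module-level settings, then the VERTEXAI_* secrets. A sketch of the two equivalent ways to supply them, with all values as placeholders:

import os
import litellm

# Option 1: environment variables, picked up through get_secret(...)
os.environ["VERTEXAI_PROJECT"] = "my-gcp-project"
os.environ["VERTEXAI_LOCATION"] = "us-central1"
os.environ["VERTEXAI_CREDENTIALS"] = "/path/to/service_account.json"

# Option 2: per-call kwargs, which are popped out of optional_params above
response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro",  # placeholder model name
    messages=[{"role": "user", "content": "hello"}],
    vertex_project="my-gcp-project",
    vertex_location="us-central1",
    vertex_credentials="/path/to/service_account.json",
)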

View file

@@ -1029,7 +1029,8 @@ def test_completion_claude_stream_bad_key():
# test_completion_replicate_stream()
def test_vertex_ai_stream():
@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"])
def test_vertex_ai_stream(provider):
from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials
load_vertex_ai_credentials()
@@ -1042,7 +1043,7 @@ def test_vertex_ai_stream():
try:
print("making request", model)
response = completion(
model=model,
model="{}/{}".format(provider, model),
messages=[
{"role": "user", "content": "write 10 line code code for saying hi"}
],

View file

@@ -323,3 +323,9 @@ class ChatCompletionResponseMessage(TypedDict, total=False):
content: Optional[str]
tool_calls: List[ChatCompletionToolCallChunk]
role: Literal["assistant"]
class ChatCompletionUsageBlock(TypedDict):
prompt_tokens: int
completion_tokens: int
total_tokens: int

View file

@@ -1,6 +1,8 @@
from typing import List, Optional, Union, Dict, Tuple, Literal
from typing_extensions import TypedDict
from enum import Enum
from typing_extensions import override, Required, Dict
from .llms.openai import ChatCompletionUsageBlock, ChatCompletionToolCallChunk
class LiteLLMCommonStrings(Enum):
@@ -37,3 +39,12 @@ class ModelInfo(TypedDict):
"completion", "embedding", "image_generation", "chat", "audio_transcription"
]
supported_openai_params: Optional[List[str]]
class GenericStreamingChunk(TypedDict):
text: Required[str]
tool_use: Optional[ChatCompletionToolCallChunk]
is_finished: Required[bool]
finish_reason: Required[str]
usage: Optional[ChatCompletionUsageBlock]
index: int
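
GenericStreamingChunk is the provider-agnostic shape that handlers such as the Vertex AI iterator above hand back to CustomStreamWrapper. A small construction example with illustrative values:

from litellm.types.llms.openai import ChatCompletionUsageBlock
from litellm.types.utils import GenericStreamingChunk

usage: ChatCompletionUsageBlock = {
    "prompt_tokens": 5,
    "completion_tokens": 2,
    "total_tokens": 7,
}

chunk: GenericStreamingChunk = {
    "text": "Hello!",
    "tool_use": None,  # set to a ChatCompletionToolCallChunk when a tool call streams in
    "is_finished": True,
    "finish_reason": "stop",
    "usage": usage,
    "index": 0,
}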

View file

@@ -11223,6 +11223,34 @@ class CustomStreamWrapper:
)
else:
completion_obj["content"] = str(chunk)
elif self.custom_llm_provider and (
self.custom_llm_provider == "vertex_ai_beta"
):
from litellm.types.utils import (
GenericStreamingChunk as UtilsStreamingChunk,
)
if self.received_finish_reason is not None:
raise StopIteration
response_obj: UtilsStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["prompt_tokens"],
completion_tokens=response_obj["usage"]["completion_tokens"],
total_tokens=response_obj["usage"]["total_tokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
import proto # type: ignore
@@ -11900,6 +11928,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "ollama"
or self.custom_llm_provider == "ollama_chat"
or self.custom_llm_provider == "vertex_ai"
or self.custom_llm_provider == "vertex_ai_beta"
or self.custom_llm_provider == "sagemaker"
or self.custom_llm_provider == "gemini"
or self.custom_llm_provider == "replicate"
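
With vertex_ai_beta added to the provider checks above, CustomStreamWrapper can surface the usage block reported by the Gemini stream. A sketch of consuming such a stream, assuming stream_options is honored on this route (the branch above reads self.stream_options) and using placeholder model and credential settings:

import litellm

response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro",  # placeholder model name
    messages=[{"role": "user", "content": "write 10 lines of code for saying hi"}],
    stream=True,
    stream_options={"include_usage": True},  # asks the wrapper to attach usage, per the branch above
)

last_chunk = None
for chunk in response:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
    last_chunk = chunk

print()
# Per the wrapper branch above, usage is attached once the stream reports a finish reason.
print(getattr(last_chunk, "usage", None))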