fix(vertex_httpx.py): support streaming via httpx client

parent: 3b913443fe
commit: 3955b058ed

7 changed files with 283 additions and 26 deletions
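In short: Vertex AI Gemini requests can now be streamed through the new httpx-based "vertex_ai_beta" provider. A minimal usage sketch, assuming Vertex credentials/project/location are already configured for litellm and using an illustrative model id (not taken from this diff):

import litellm

# Assumes vertex project/location/credentials are already set up for litellm.
response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro",  # illustrative model id; any Gemini model routed via vertex_ai_beta
    messages=[{"role": "user", "content": "write 10 line code for saying hi"}],
    stream=True,  # returns a CustomStreamWrapper fed by the new ModelResponseIterator
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")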
@@ -605,6 +605,7 @@ provider_list: List = [
    "together_ai",
    "openrouter",
    "vertex_ai",
    "vertex_ai_beta",
    "palm",
    "gemini",
    "ai21",
@@ -1,3 +1,7 @@
# What is this?
## httpx client for vertex ai calls
## Initial implementation - covers gemini + image gen calls
from functools import partial
import os, types
import json
from enum import Enum
@@ -17,6 +21,86 @@ from litellm.types.llms.vertex_ai import (
    GenerateContentResponseBody,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.utils import GenericStreamingChunk
from litellm.types.llms.openai import (
    ChatCompletionUsageBlock,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
)


class VertexGeminiConfig:
    def __init__(self) -> None:
        pass

    def supports_system_message(self) -> bool:
        """
        Not all gemini models support system instructions
        """
        return True


async def make_call(
    client: Optional[AsyncHTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    if client is None:
        client = AsyncHTTPHandler()  # Create a new client if none provided

    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise VertexAIError(status_code=response.status_code, message=response.text)

    completion_stream = ModelResponseIterator(
        streaming_response=response.aiter_bytes(chunk_size=2056)
    )
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


def make_sync_call(
    client: Optional[HTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    if client is None:
        client = HTTPHandler()  # Create a new client if none provided

    response = client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise VertexAIError(status_code=response.status_code, message=response.read())

    completion_stream = ModelResponseIterator(
        streaming_response=response.iter_bytes(chunk_size=2056)
    )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


class VertexAIError(Exception):
@@ -46,7 +130,6 @@ class VertexLLM(BaseLLM):
        model: str,
        response: httpx.Response,
        model_response: ModelResponse,
        stream: bool,
        logging_obj: litellm.utils.Logging,
        optional_params: dict,
        api_key: str,
@@ -77,7 +160,7 @@ class VertexLLM(BaseLLM):
                status_code=422,
            )

        model_response.choices = []
        model_response.choices = []  # type: ignore

        ## GET MODEL ##
        model_response.model = model
@@ -190,6 +273,16 @@ class VertexLLM(BaseLLM):

        return self._credentials.token, self.project_id

    async def async_streaming(
        self,
    ):
        pass

    async def async_completion(
        self,
    ):
        pass

    def completion(
        self,
        model: str,
@@ -214,7 +307,7 @@ class VertexLLM(BaseLLM):
            credentials=vertex_credentials, project_id=vertex_project
        )
        vertex_location = self.get_vertex_region(vertex_region=vertex_location)
        stream = optional_params.pop("stream", None)
        stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore

        ### SET RUNTIME ENDPOINT ###
        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
@@ -251,6 +344,26 @@ class VertexLLM(BaseLLM):
            },
        )

        ## SYNC STREAMING CALL ##
        if stream is not None and stream is True:
            streaming_response = CustomStreamWrapper(
                completion_stream=None,
                make_call=partial(
                    make_sync_call,
                    client=None,
                    api_base=url,
                    headers=headers,  # type: ignore
                    data=json.dumps(data),
                    model=model,
                    messages=messages,
                    logging_obj=logging_obj,
                ),
                model=model,
                custom_llm_provider="bedrock",
                logging_obj=logging_obj,
            )

            return streaming_response
        ## COMPLETION CALL ##
        if client is None or isinstance(client, AsyncHTTPHandler):
            _params = {}
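Note how the streaming branch above passes completion_stream=None plus a make_call callable built with functools.partial, so the HTTP request is only issued when the wrapper is first iterated. A standalone sketch of that deferred-call pattern, using hypothetical names (LazyStream, open_stream) rather than litellm internals:

from functools import partial
from typing import Callable, Iterator, Optional


def open_stream(api_base: str, data: str) -> Iterator[str]:
    # Hypothetical stand-in for make_sync_call: the connection would be opened
    # here, and raw chunks yielded as they arrive.
    yield from (f"chunk from {api_base}: {data}",)


class LazyStream:
    # Hypothetical stand-in for CustomStreamWrapper's deferred make_call behaviour.
    def __init__(
        self,
        completion_stream: Optional[Iterator[str]],
        make_call: Callable[[], Iterator[str]],
    ):
        self.completion_stream = completion_stream
        self.make_call = make_call

    def __iter__(self) -> Iterator[str]:
        if self.completion_stream is None:
            # The request is only issued here, on first iteration.
            self.completion_stream = self.make_call()
        return iter(self.completion_stream)


stream = LazyStream(
    completion_stream=None,
    make_call=partial(open_stream, api_base="https://example.invalid", data="{}"),
)
for chunk in stream:
    print(chunk)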
@@ -274,7 +387,6 @@ class VertexLLM(BaseLLM):
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            optional_params=optional_params,
            api_key="",
@@ -421,3 +533,84 @@ class VertexLLM(BaseLLM):
        model_response.data = _response_data

        return model_response


class ModelResponseIterator:
    def __init__(self, streaming_response):
        self.streaming_response = streaming_response
        self.response_iterator = iter(self.streaming_response)

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            processed_chunk = GenerateContentResponseBody(**chunk)  # type: ignore
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None

            gemini_chunk = processed_chunk["candidates"][0]

            if (
                "content" in gemini_chunk
                and "text" in gemini_chunk["content"]["parts"][0]
            ):
                text = gemini_chunk["content"]["parts"][0]["text"]

            if "finishReason" in gemini_chunk:
                finish_reason = map_finish_reason(
                    finish_reason=gemini_chunk["finishReason"]
                )
                is_finished = True

            if "usageMetadata" in processed_chunk:
                usage = ChatCompletionUsageBlock(
                    prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"],
                    completion_tokens=processed_chunk["usageMetadata"][
                        "candidatesTokenCount"
                    ],
                    total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"],
                )

            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=0,
            )
            return returned_chunk
        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = next(self.response_iterator)
            chunk = chunk.decode()
            json_chunk = json.loads(chunk)
            return self.chunk_parser(chunk=json_chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e}")

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        try:
            chunk = await self.async_response_iterator.__anext__()
            chunk = chunk.decode()
            json_chunk = json.loads(chunk)
            return self.chunk_parser(chunk=json_chunk)
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e}")
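For reference, chunk_parser above only reads a handful of keys from each Gemini streaming payload. A minimal sketch with a hand-built chunk (the values are made up, and the import path is assumed from the commit's file name, vertex_httpx.py):

from litellm.llms.vertex_httpx import ModelResponseIterator

# Illustrative only: a dict shaped like one Gemini streaming chunk, using
# exactly the keys chunk_parser reads. Values are invented for the example.
sample_chunk = {
    "candidates": [
        {
            "content": {"parts": [{"text": "Hello"}]},
            "finishReason": "STOP",
        }
    ],
    "usageMetadata": {
        "promptTokenCount": 5,
        "candidatesTokenCount": 1,
        "totalTokenCount": 6,
    },
}

parser = ModelResponseIterator(streaming_response=[])  # empty stream; only chunk_parser is exercised
parsed = parser.chunk_parser(chunk=sample_chunk)
print(parsed["text"], parsed["is_finished"], parsed["usage"])  # -> Hello True {...}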
@@ -1875,6 +1875,42 @@ def completion(
            )
            return response
        response = model_response
    elif custom_llm_provider == "vertex_ai_beta":
        vertex_ai_project = (
            optional_params.pop("vertex_project", None)
            or optional_params.pop("vertex_ai_project", None)
            or litellm.vertex_project
            or get_secret("VERTEXAI_PROJECT")
        )
        vertex_ai_location = (
            optional_params.pop("vertex_location", None)
            or optional_params.pop("vertex_ai_location", None)
            or litellm.vertex_location
            or get_secret("VERTEXAI_LOCATION")
        )
        vertex_credentials = (
            optional_params.pop("vertex_credentials", None)
            or optional_params.pop("vertex_ai_credentials", None)
            or get_secret("VERTEXAI_CREDENTIALS")
        )
        new_params = deepcopy(optional_params)
        response = vertex_chat_completion.completion(  # type: ignore
            model=model,
            messages=messages,
            model_response=model_response,
            print_verbose=print_verbose,
            optional_params=new_params,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            encoding=encoding,
            vertex_location=vertex_ai_location,
            vertex_project=vertex_ai_project,
            vertex_credentials=vertex_credentials,
            logging_obj=logging,
            acompletion=acompletion,
            timeout=timeout,
        )

    elif custom_llm_provider == "vertex_ai":
        vertex_ai_project = (
            optional_params.pop("vertex_project", None)
@@ -1911,26 +1947,6 @@ def completion(
                logging_obj=logging,
                acompletion=acompletion,
            )
        elif (
            model in litellm.vertex_language_models
            or model in litellm.vertex_vision_models
        ):
            model_response = vertex_chat_completion.completion(  # type: ignore
                model=model,
                messages=messages,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=new_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                encoding=encoding,
                vertex_location=vertex_ai_location,
                vertex_project=vertex_ai_project,
                vertex_credentials=vertex_credentials,
                logging_obj=logging,
                acompletion=acompletion,
                timeout=timeout,
            )
        else:
            model_response = vertex_ai.completion(
                model=model,
@@ -1029,7 +1029,8 @@ def test_completion_claude_stream_bad_key():
# test_completion_replicate_stream()


def test_vertex_ai_stream():
@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"])
def test_vertex_ai_stream(provider):
    from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials

    load_vertex_ai_credentials()
@@ -1042,7 +1043,7 @@ def test_vertex_ai_stream():
    try:
        print("making request", model)
        response = completion(
            model=model,
            model="{}/{}".format(provider, model),
            messages=[
                {"role": "user", "content": "write 10 line code code for saying hi"}
            ],
@@ -323,3 +323,9 @@ class ChatCompletionResponseMessage(TypedDict, total=False):
    content: Optional[str]
    tool_calls: List[ChatCompletionToolCallChunk]
    role: Literal["assistant"]


class ChatCompletionUsageBlock(TypedDict):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
@@ -1,6 +1,8 @@
from typing import List, Optional, Union, Dict, Tuple, Literal
from typing_extensions import TypedDict
from enum import Enum
from typing_extensions import override, Required, Dict
from .llms.openai import ChatCompletionUsageBlock, ChatCompletionToolCallChunk


class LiteLLMCommonStrings(Enum):
@@ -37,3 +39,12 @@ class ModelInfo(TypedDict):
        "completion", "embedding", "image_generation", "chat", "audio_transcription"
    ]
    supported_openai_params: Optional[List[str]]


class GenericStreamingChunk(TypedDict):
    text: Required[str]
    tool_use: Optional[ChatCompletionToolCallChunk]
    is_finished: Required[bool]
    finish_reason: Required[str]
    usage: Optional[ChatCompletionUsageBlock]
    index: int
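A populated instance of this TypedDict, as produced by the chunk_parser shown earlier, might look like the following (values are illustrative):

from litellm.types.utils import GenericStreamingChunk

# Illustrative values only; "usage" follows the ChatCompletionUsageBlock shape.
chunk: GenericStreamingChunk = {
    "text": "Hello",
    "tool_use": None,
    "is_finished": True,
    "finish_reason": "stop",
    "usage": {"prompt_tokens": 5, "completion_tokens": 1, "total_tokens": 6},
    "index": 0,
}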
@@ -11223,6 +11223,34 @@ class CustomStreamWrapper:
                    )
                else:
                    completion_obj["content"] = str(chunk)
            elif self.custom_llm_provider and (
                self.custom_llm_provider == "vertex_ai_beta"
            ):
                from litellm.types.utils import (
                    GenericStreamingChunk as UtilsStreamingChunk,
                )

                if self.received_finish_reason is not None:
                    raise StopIteration
                response_obj: UtilsStreamingChunk = chunk
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]

                if (
                    self.stream_options
                    and self.stream_options.get("include_usage", False) is True
                    and response_obj["usage"] is not None
                ):
                    self.sent_stream_usage = True
                    model_response.usage = litellm.Usage(
                        prompt_tokens=response_obj["usage"]["prompt_tokens"],
                        completion_tokens=response_obj["usage"]["completion_tokens"],
                        total_tokens=response_obj["usage"]["total_tokens"],
                    )

                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                import proto  # type: ignore
@@ -11900,6 +11928,7 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "ollama"
                or self.custom_llm_provider == "ollama_chat"
                or self.custom_llm_provider == "vertex_ai"
                or self.custom_llm_provider == "vertex_ai_beta"
                or self.custom_llm_provider == "sagemaker"
                or self.custom_llm_provider == "gemini"
                or self.custom_llm_provider == "replicate"