diff --git a/litellm/__init__.py b/litellm/__init__.py
index e18be347d..19c3bcca6 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -605,6 +605,7 @@ provider_list: List = [
     "together_ai",
     "openrouter",
     "vertex_ai",
+    "vertex_ai_beta",
     "palm",
     "gemini",
     "ai21",
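The new provider_list entry is what makes the "vertex_ai_beta/" model prefix resolve during routing. A minimal sketch of the effect, assuming get_llm_provider keeps its usual 4-tuple return (model, provider, dynamic api key, api base):

    import litellm

    # With "vertex_ai_beta" registered in litellm.provider_list, the prefix is
    # recognized and stripped before the call is dispatched in litellm.main.completion.
    model, provider, _, _ = litellm.get_llm_provider(model="vertex_ai_beta/gemini-pro")
    assert provider == "vertex_ai_beta"
    assert model == "gemini-pro"
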
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index acf79bebb..70a408c2b 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -1,3 +1,7 @@
+# What is this?
+## httpx client for vertex ai calls
+## Initial implementation - covers gemini + image gen calls
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -17,6 +21,86 @@ from litellm.types.llms.vertex_ai import (
     GenerateContentResponseBody,
 )
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
+from litellm.types.utils import GenericStreamingChunk
+from litellm.types.llms.openai import (
+    ChatCompletionUsageBlock,
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+)
+
+
+class VertexGeminiConfig:
+    def __init__(self) -> None:
+        pass
+
+    def supports_system_message(self) -> bool:
+        """
+        Not all gemini models support system instructions
+        """
+        return True
+
+
+async def make_call(
+    client: Optional[AsyncHTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = AsyncHTTPHandler()  # Create a new client if none provided
+
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise VertexAIError(status_code=response.status_code, message=response.text)
+
+    completion_stream = ModelResponseIterator(
+        streaming_response=response.aiter_bytes(chunk_size=2056)
+    )
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response="first stream response received",
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
+def make_sync_call(
+    client: Optional[HTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = HTTPHandler()  # Create a new client if none provided
+
+    response = client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise VertexAIError(status_code=response.status_code, message=response.read())
+
+    completion_stream = ModelResponseIterator(
+        streaming_response=response.iter_bytes(chunk_size=2056)
+    )
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response="first stream response received",
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream


 class VertexAIError(Exception):
@@ -46,7 +130,6 @@ class VertexLLM(BaseLLM):
         model: str,
         response: httpx.Response,
         model_response: ModelResponse,
-        stream: bool,
         logging_obj: litellm.utils.Logging,
         optional_params: dict,
         api_key: str,
@@ -77,7 +160,7 @@
                 status_code=422,
             )

-        model_response.choices = []
+        model_response.choices = []  # type: ignore

         ## GET MODEL ##
         model_response.model = model
@@ -190,6 +273,16 @@

         return self._credentials.token, self.project_id

+    async def async_streaming(
+        self,
+    ):
+        pass
+
+    async def async_completion(
+        self,
+    ):
+        pass
+
     def completion(
         self,
         model: str,
@@ -214,7 +307,7 @@
             credentials=vertex_credentials, project_id=vertex_project
         )
         vertex_location = self.get_vertex_region(vertex_region=vertex_location)
-        stream = optional_params.pop("stream", None)
+        stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore

         ### SET RUNTIME ENDPOINT ###
         url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
@@ -251,6 +344,26 @@
             },
         )

+        ## SYNC STREAMING CALL ##
+        if stream is not None and stream is True:
+            streaming_response = CustomStreamWrapper(
+                completion_stream=None,
+                make_call=partial(
+                    make_sync_call,
+                    client=None,
+                    api_base=url,
+                    headers=headers,  # type: ignore
+                    data=json.dumps(data),
+                    model=model,
+                    messages=messages,
+                    logging_obj=logging_obj,
+                ),
+                model=model,
+                custom_llm_provider="bedrock",
+                logging_obj=logging_obj,
+            )
+
+            return streaming_response
         ## COMPLETION CALL ##
         if client is None or isinstance(client, AsyncHTTPHandler):
             _params = {}
@@ -274,7 +387,6 @@
             model=model,
             response=response,
             model_response=model_response,
-            stream=stream,
             logging_obj=logging_obj,
             optional_params=optional_params,
             api_key="",
@@ -421,3 +533,84 @@
             model_response.data = _response_data

         return model_response
+
+
+class ModelResponseIterator:
+    def __init__(self, streaming_response):
+        self.streaming_response = streaming_response
+        self.response_iterator = iter(self.streaming_response)
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            processed_chunk = GenerateContentResponseBody(**chunk)  # type: ignore
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+
+            gemini_chunk = processed_chunk["candidates"][0]
+
+            if (
+                "content" in gemini_chunk
+                and "text" in gemini_chunk["content"]["parts"][0]
+            ):
+                text = gemini_chunk["content"]["parts"][0]["text"]
+
+            if "finishReason" in gemini_chunk:
+                finish_reason = map_finish_reason(
+                    finish_reason=gemini_chunk["finishReason"]
+                )
+                is_finished = True
+
+            if "usageMetadata" in processed_chunk:
+                usage = ChatCompletionUsageBlock(
+                    prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"],
+                    completion_tokens=processed_chunk["usageMetadata"][
+                        "candidatesTokenCount"
+                    ],
+                    total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"],
+                )
+
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=0,
+            )
+            return returned_chunk
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            chunk = next(self.response_iterator)
+            chunk = chunk.decode()
+            json_chunk = json.loads(chunk)
+            return self.chunk_parser(chunk=json_chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+            chunk = chunk.decode()
+            json_chunk = json.loads(chunk)
+            return self.chunk_parser(chunk=json_chunk)
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e}")
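For reference, a minimal sketch of how ModelResponseIterator.chunk_parser above maps one decoded Gemini response body onto a GenericStreamingChunk. The payload is illustrative only; its field names (candidates, finishReason, usageMetadata) are taken from the parser, not from a captured response:

    from litellm.llms.vertex_httpx import ModelResponseIterator

    sample_chunk = {
        "candidates": [
            {
                "content": {"role": "model", "parts": [{"text": "Hello!"}]},
                "finishReason": "STOP",
            }
        ],
        "usageMetadata": {
            "promptTokenCount": 8,
            "candidatesTokenCount": 2,
            "totalTokenCount": 10,
        },
    }

    # streaming_response is not used by chunk_parser, so an empty list is enough here.
    parsed = ModelResponseIterator(streaming_response=[]).chunk_parser(chunk=sample_chunk)
    # parsed["text"] == "Hello!", parsed["is_finished"] is True,
    # and parsed["usage"]["total_tokens"] == 10
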
diff --git a/litellm/main.py b/litellm/main.py
index 63b86b43b..16fd394f8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1875,6 +1875,42 @@ def completion(
                 )
                 return response
             response = model_response
+        elif custom_llm_provider == "vertex_ai_beta":
+            vertex_ai_project = (
+                optional_params.pop("vertex_project", None)
+                or optional_params.pop("vertex_ai_project", None)
+                or litellm.vertex_project
+                or get_secret("VERTEXAI_PROJECT")
+            )
+            vertex_ai_location = (
+                optional_params.pop("vertex_location", None)
+                or optional_params.pop("vertex_ai_location", None)
+                or litellm.vertex_location
+                or get_secret("VERTEXAI_LOCATION")
+            )
+            vertex_credentials = (
+                optional_params.pop("vertex_credentials", None)
+                or optional_params.pop("vertex_ai_credentials", None)
+                or get_secret("VERTEXAI_CREDENTIALS")
+            )
+            new_params = deepcopy(optional_params)
+            response = vertex_chat_completion.completion(  # type: ignore
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=new_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                vertex_location=vertex_ai_location,
+                vertex_project=vertex_ai_project,
+                vertex_credentials=vertex_credentials,
+                logging_obj=logging,
+                acompletion=acompletion,
+                timeout=timeout,
+            )
+
         elif custom_llm_provider == "vertex_ai":
             vertex_ai_project = (
                 optional_params.pop("vertex_project", None)
@@ -1911,26 +1947,6 @@
                 logging_obj=logging,
                 acompletion=acompletion,
             )
-        elif (
-            model in litellm.vertex_language_models
-            or model in litellm.vertex_vision_models
-        ):
-            model_response = vertex_chat_completion.completion(  # type: ignore
-                model=model,
-                messages=messages,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=new_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                vertex_location=vertex_ai_location,
-                vertex_project=vertex_ai_project,
-                vertex_credentials=vertex_credentials,
-                logging_obj=logging,
-                acompletion=acompletion,
-                timeout=timeout,
-            )
         else:
             model_response = vertex_ai.completion(
                 model=model,
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index ac107d28e..c23dbb7d9 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1029,7 +1029,8 @@ def test_completion_claude_stream_bad_key():
 # test_completion_replicate_stream()


-def test_vertex_ai_stream():
+@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"])
+def test_vertex_ai_stream(provider):
     from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials

     load_vertex_ai_credentials()
@@ -1042,7 +1043,7 @@
     try:
         print("making request", model)
         response = completion(
-            model=model,
+            model="{}/{}".format(provider, model),
             messages=[
                 {"role": "user", "content": "write 10 line code code for saying hi"}
             ],
diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
index 66aec4906..88f498ede 100644
--- a/litellm/types/llms/openai.py
+++ b/litellm/types/llms/openai.py
@@ -323,3 +323,9 @@ class ChatCompletionResponseMessage(TypedDict, total=False):
     content: Optional[str]
     tool_calls: List[ChatCompletionToolCallChunk]
     role: Literal["assistant"]
+
+
+class ChatCompletionUsageBlock(TypedDict):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 2b6aefcf5..1fbb375d3 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1,6 +1,8 @@
 from typing import List, Optional, Union, Dict, Tuple, Literal
 from typing_extensions import TypedDict
 from enum import Enum
+from typing_extensions import override, Required, Dict
+from .llms.openai import ChatCompletionUsageBlock, ChatCompletionToolCallChunk


 class LiteLLMCommonStrings(Enum):
@@ -37,3 +39,12 @@ class ModelInfo(TypedDict):
         "completion", "embedding", "image_generation", "chat", "audio_transcription"
     ]
     supported_openai_params: Optional[List[str]]
+
+
+class GenericStreamingChunk(TypedDict):
+    text: Required[str]
+    tool_use: Optional[ChatCompletionToolCallChunk]
+    is_finished: Required[bool]
+    finish_reason: Required[str]
+    usage: Optional[ChatCompletionUsageBlock]
+    index: int
diff --git a/litellm/utils.py b/litellm/utils.py
index 14041d2b6..f132e3202 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -11223,6 +11223,34 @@ class CustomStreamWrapper:
                     )
                 else:
                     completion_obj["content"] = str(chunk)
+            elif self.custom_llm_provider and (
+                self.custom_llm_provider == "vertex_ai_beta"
+            ):
+                from litellm.types.utils import (
+                    GenericStreamingChunk as UtilsStreamingChunk,
+                )
+
+                if self.received_finish_reason is not None:
+                    raise StopIteration
+                response_obj: UtilsStreamingChunk = chunk
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
+
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) is True
+                    and response_obj["usage"] is not None
+                ):
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"]["prompt_tokens"],
+                        completion_tokens=response_obj["usage"]["completion_tokens"],
+                        total_tokens=response_obj["usage"]["total_tokens"],
+                    )
+
+                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
             elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                 import proto  # type: ignore

@@ -11900,6 +11928,7 @@
                     or self.custom_llm_provider == "ollama"
                     or self.custom_llm_provider == "ollama_chat"
                     or self.custom_llm_provider == "vertex_ai"
+                    or self.custom_llm_provider == "vertex_ai_beta"
                    or self.custom_llm_provider == "sagemaker"
                     or self.custom_llm_provider == "gemini"
                     or self.custom_llm_provider == "replicate"
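Taken together, the changes above can be exercised end to end through the new provider prefix. An illustrative sketch only: the model name is a placeholder, credentials are assumed to be configured the same way as for the existing vertex_ai provider, and stream_options is passed on the assumption that completion() accepts it as it does for OpenAI-compatible providers:

    import litellm

    response = litellm.completion(
        model="vertex_ai_beta/gemini-pro",
        messages=[{"role": "user", "content": "write 10 line code for saying hi"}],
        stream=True,
        # per the CustomStreamWrapper change above, usage is attached to the
        # final chunk when include_usage is requested
        stream_options={"include_usage": True},
    )
    for chunk in response:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")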