From 995631bd397eef763ee218d2cb63e20913d6ab34 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 12 Jun 2024 20:15:03 -0700
Subject: [PATCH] fix(vertex_httpx.py): support async completion calls

---
 litellm/llms/vertex_httpx.py                  | 107 +++++++++++++++++-
 litellm/main.py                               |   1 +
 .../tests/test_amazing_vertex_completion.py   |  56 ++++++---
 3 files changed, 142 insertions(+), 22 deletions(-)

diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 70a408c2b..550fffe4a 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -275,13 +275,89 @@ class VertexLLM(BaseLLM):
 
     async def async_streaming(
         self,
-    ):
-        pass
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> CustomStreamWrapper:
+        streaming_response = CustomStreamWrapper(
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                client=client,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
+            model=model,
+            custom_llm_provider="vertex_ai_beta",
+            logging_obj=logging_obj,
+        )
+        return streaming_response
 
     async def async_completion(
         self,
-    ):
-        pass
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        return self._process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
 
     def completion(
         self,
@@ -344,6 +420,27 @@ class VertexLLM(BaseLLM):
             },
         )
 
+        ### ROUTING (ASYNC, STREAMING, SYNC)
+        if acompletion:
+            ### ASYNC COMPLETION
+            return self.async_completion(
+                model=model,
+                messages=messages,
+                data=data,  # type: ignore
+                api_base=url,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                encoding=encoding,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                stream=stream,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                headers=headers,
+                timeout=timeout,
+                client=client,  # type: ignore
+            )
+
         ## SYNC STREAMING CALL ##
         if stream is not None and stream is True:
             streaming_response = CustomStreamWrapper(
@@ -359,7 +456,7 @@ class VertexLLM(BaseLLM):
                     logging_obj=logging_obj,
                 ),
                 model=model,
-                custom_llm_provider="bedrock",
+                custom_llm_provider="vertex_ai_beta",
                 logging_obj=logging_obj,
             )
 
diff --git a/litellm/main.py b/litellm/main.py
index 16fd394f8..83104290d 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -329,6 +329,7 @@ async def acompletion(
         or custom_llm_provider == "ollama_chat"
         or custom_llm_provider == "replicate"
         or custom_llm_provider == "vertex_ai"
+        or custom_llm_provider == "vertex_ai_beta"
         or custom_llm_provider == "gemini"
         or custom_llm_provider == "sagemaker"
         or custom_llm_provider == "anthropic"
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index cf49fd130..7f0b49808 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -503,28 +503,50 @@ async def test_async_vertexai_streaming_response():
 # asyncio.run(test_async_vertexai_streaming_response())
 
 
-def test_gemini_pro_vision():
+@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"])
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_gemini_pro_vision(provider, sync_mode):
     try:
         load_vertex_ai_credentials()
         litellm.set_verbose = True
         litellm.num_retries = 3
-        resp = litellm.completion(
-            model="vertex_ai/gemini-1.5-flash-preview-0514",
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": "Whats in this image?"},
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
+        if sync_mode:
+            resp = litellm.completion(
+                model="{}/gemini-1.5-flash-preview-0514".format(provider),
+                messages=[
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": "Whats in this image?"},
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
+                                },
                             },
-                        },
-                    ],
-                }
-            ],
-        )
+                        ],
+                    }
+                ],
+            )
+        else:
+            resp = await litellm.acompletion(
+                model="{}/gemini-1.5-flash-preview-0514".format(provider),
+                messages=[
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": "Whats in this image?"},
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
+                                },
+                            },
+                        ],
+                    }
+                ],
+            )
         print(resp)
 
         prompt_tokens = resp.usage.prompt_tokens