From 69c29f8f8692aa4be9d4ddc2b9bb5ba2a32daf38 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 13 Dec 2023 11:53:55 -0800
Subject: [PATCH] fix(vertex_ai.py): add support for real async streaming + completion calls

---
 litellm/llms/vertex_ai.py                    | 98 ++++++++++++++-----
 litellm/main.py                              |  2 +-
 .../tests/test_amazing_vertex_completion.py  | 58 ++++++++---
 litellm/utils.py                             |  4 +-
 model_prices_and_context_window.json         | 21 ++--
 5 files changed, 134 insertions(+), 49 deletions(-)

diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index b7bc7935a..1f9878abe 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
 import litellm
 import httpx

@@ -108,37 +108,38 @@ def completion(
             mode = "chat"
             request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
         elif model in litellm.vertex_text_models:
-            text_model = TextGenerationModel.from_pretrained(model)
+            llm_model = TextGenerationModel.from_pretrained(model)
             mode = "text"
-            request_str += f"text_model = TextGenerationModel.from_pretrained({model})\n"
+            request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
         elif model in litellm.vertex_code_text_models:
-            text_model = CodeGenerationModel.from_pretrained(model)
+            llm_model = CodeGenerationModel.from_pretrained(model)
             mode = "text"
-            request_str += f"text_model = CodeGenerationModel.from_pretrained({model})\n"
+            request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
         else: # vertex_code_llm_models
             llm_model = CodeChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"

-        if acompletion == True and model in litellm.vertex_language_models: # [TODO] expand support to vertex ai chat + text models
+        if acompletion == True: # [TODO] expand support to vertex ai chat + text models
             if optional_params.get("stream", False) is True:
                 # async streaming
-                pass
-            return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, **optional_params)
+                return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, **optional_params)
+            return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, **optional_params)

         if mode == "":
             chat = llm_model.start_chat()
             request_str+= f"chat = llm_model.start_chat()\n"

             if "stream" in optional_params and optional_params["stream"] == True:
-                request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n"
+                stream = optional_params.pop("stream")
+                request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
                 ## LOGGING
                 logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-                model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params))
+                model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
                 optional_params["stream"] = True
                 return model_response

-            request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
+            request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params})).text\n"
             ## LOGGING
             logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
             response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params))
@@ -165,20 +166,19 @@ def completion(
             logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
             completion_response = chat.send_message(prompt, **optional_params).text
         elif mode == "text":
-
             if "stream" in optional_params and optional_params["stream"] == True:
                 optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
-                request_str += f"text_model.predict_streaming({prompt}, **{optional_params})\n"
+                request_str += f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
                 ## LOGGING
                 logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-                model_response = text_model.predict_streaming(prompt, **optional_params)
+                model_response = llm_model.predict_streaming(prompt, **optional_params)
                 optional_params["stream"] = True
                 return model_response

-            request_str += f"text_model.predict({prompt}, **{optional_params}).text\n"
+            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
             ## LOGGING
             logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-            completion_response = text_model.predict(prompt, **optional_params).text
+            completion_response = llm_model.predict(prompt, **optional_params).text

         ## LOGGING
         logging_obj.post_call(
@@ -216,7 +216,7 @@ def completion(
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))

-async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
+async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, **optional_params):
     """
     Add support for acompletion calls for gemini-pro
     """
@@ -224,19 +224,31 @@ async def async_completion(llm_model, mode: str, prompt: str, model_

     if mode == "":
         # gemini-pro
-        llm_model = llm_model.start_chat()
+        chat = llm_model.start_chat()
         ## LOGGING
         logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-        response_obj = await llm_model.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
+        response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
         completion_response = response_obj.text
         response_obj = response_obj._raw_response
     elif mode == "chat":
         # chat-bison etc.
-        pass
+        chat = llm_model.start_chat()
+        ## LOGGING
+        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+        response_obj = await chat.send_message_async(prompt, **optional_params)
+        completion_response = response_obj.text
     elif mode == "text":
         # gecko etc.
-        pass
-
+        request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
+        ## LOGGING
+        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+        response_obj = await llm_model.predict_async(prompt, **optional_params)
+        completion_response = response_obj.text
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=prompt, api_key=None, original_response=completion_response
+    )

     ## RESPONSE OBJECT
     if len(str(completion_response)) > 0:
@@ -252,13 +264,53 @@ async def async_completion(llm_model, mode: str, prompt: str, model_
         usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
                     completion_tokens=response_obj.usage_metadata.candidates_token_count,
                     total_tokens=response_obj.usage_metadata.total_token_count)
+    else:
+        prompt_tokens = len(
+            encoding.encode(prompt)
+        )
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
     model_response.usage = usage
     return model_response

-def async_streaming():
+async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
     """
     Add support for async streaming calls for gemini-pro
     """
+    from vertexai.preview.generative_models import GenerationConfig
+    if mode == "":
+        # gemini-pro
+        chat = llm_model.start_chat()
+        stream = optional_params.pop("stream")
+        request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
+        ## LOGGING
+        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+        response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
+        optional_params["stream"] = True
+    elif mode == "chat":
+        chat = llm_model.start_chat()
+        optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
+        request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
+        ## LOGGING
+        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+        response = chat.send_message_streaming_async(prompt, **optional_params)
+        optional_params["stream"] = True
+    elif mode == "text":
+        optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
+        request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
+        ## LOGGING
+        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+        response = llm_model.predict_streaming_async(prompt, **optional_params)
+
+    streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="vertex_ai",logging_obj=logging_obj)
+    async for transformed_chunk in streamwrapper:
+        yield transformed_chunk

 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
diff --git a/litellm/main.py b/litellm/main.py
index d8850c7e6..38526bfd6 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1157,7 +1157,7 @@ def completion(
                 acompletion=acompletion
             )

-            if "stream" in optional_params and optional_params["stream"] == True:
+            if "stream" in optional_params and optional_params["stream"] == True and acompletion == False:
                 response = CustomStreamWrapper(
                     model_response, model, custom_llm_provider="vertex_ai", logging_obj=logging
                 )
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 620962128..8a227c454 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -73,7 +73,7 @@ def test_vertex_ai():
     litellm.vertex_project = "hardy-device-386718"

     test_models = random.sample(test_models, 4)
-    test_models = litellm.vertex_language_models # always test gemini-pro
+    test_models += litellm.vertex_language_models # always test gemini-pro
     for model in test_models:
         try:
             if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
@@ -87,7 +87,7 @@ def test_vertex_ai():
             assert len(response.choices[0].message.content) > 1
         except Exception as e:
             pytest.fail(f"Error occurred: {e}")
-test_vertex_ai()
+# test_vertex_ai()

 def test_vertex_ai_stream():
     load_vertex_ai_credentials()
@@ -120,16 +120,48 @@

 @pytest.mark.asyncio
 async def test_async_vertexai_response():
+    import random
     load_vertex_ai_credentials()
-    user_message = "Hello, how are you?"
-    messages = [{"content": user_message, "role": "user"}]
-    try:
-        response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5)
-        # response = await response
-        print(f"response: {response}")
-    except litellm.Timeout as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"An exception occurred: {e}")
+    test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
+    test_models = random.sample(test_models, 4)
+    test_models += litellm.vertex_language_models # always test gemini-pro
+    for model in test_models:
+        print(f'model being tested in async call: {model}')
+        try:
+            user_message = "Hello, how are you?"
+            messages = [{"content": user_message, "role": "user"}]
+            response = await acompletion(model=model, messages=messages, temperature=0.7, timeout=5)
+            print(f"response: {response}")
+        except litellm.Timeout as e:
+            pass
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")

-asyncio.run(test_async_vertexai_response())
\ No newline at end of file
+# asyncio.run(test_async_vertexai_response())
+
+@pytest.mark.asyncio
+async def test_async_vertexai_streaming_response():
+    import random
+    load_vertex_ai_credentials()
+    test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
+    test_models = random.sample(test_models, 4)
+    test_models += litellm.vertex_language_models # always test gemini-pro
+    for model in test_models:
+        try:
+            user_message = "Hello, how are you?"
+            messages = [{"content": user_message, "role": "user"}]
+            response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5, stream=True)
+            print(f"response: {response}")
+            complete_response = ""
+            async for chunk in response:
+                print(f"chunk: {chunk}")
+                complete_response += chunk.choices[0].delta.content
+            print(f"complete_response: {complete_response}")
+            assert len(complete_response) > 0
+        except litellm.Timeout as e:
+            pass
+        except Exception as e:
+            print(e)
+            pytest.fail(f"An exception occurred: {e}")
+
+# asyncio.run(test_async_vertexai_streaming_response())
diff --git a/litellm/utils.py b/litellm/utils.py
index f5aa0d15f..c2aab4700 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -19,6 +19,7 @@ import uuid
 import aiohttp
 import logging
 import asyncio, httpx, inspect
+from inspect import iscoroutine
 import copy
 from tokenizers import Tokenizer
 from dataclasses import (
@@ -5769,7 +5770,8 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "azure"
                 or self.custom_llm_provider == "custom_openai"
                 or self.custom_llm_provider == "text-completion-openai"
-                or self.custom_llm_provider == "huggingface"):
+                or self.custom_llm_provider == "huggingface"
+                or self.custom_llm_provider == "vertex_ai"):
                 async for chunk in self.completion_stream:
                     if chunk == "None" or chunk is None:
                         raise Exception
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 1d0ca5038..4b0bb2bfb 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -294,14 +294,21 @@
         "max_tokens": 2048,
         "input_cost_per_token": 0.000000125,
         "output_cost_per_token": 0.000000125,
-        "litellm_provider": "vertex_ai-chat-models",
+        "litellm_provider": "vertex_ai-code-text-models",
         "mode": "completion"
     },
-    "code-gecko@latest": {
+    "code-gecko@002": {
         "max_tokens": 2048,
         "input_cost_per_token": 0.000000125,
         "output_cost_per_token": 0.000000125,
-        "litellm_provider": "vertex_ai-chat-models",
+        "litellm_provider": "vertex_ai-code-text-models",
+        "mode": "completion"
+    },
+    "code-gecko": {
+        "max_tokens": 2048,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125,
+        "litellm_provider": "vertex_ai-code-text-models",
         "mode": "completion"
     },
     "codechat-bison": {
@@ -340,14 +347,6 @@
         "litellm_provider": "palm",
         "mode": "chat"
     },
-    "gemini-pro": {
-        "max_tokens": 30720,
-        "max_output_tokens": 2048,
-        "input_cost_per_token": 0.0000000625,
-        "output_cost_per_token": 0.000000125,
-        "litellm_provider": "vertex_ai-language-models",
-        "mode": "chat"
-    },
     "palm/chat-bison-001": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000000125,
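Usage sketch (not part of the patch): the async streaming path wired up above is exercised by the new test_async_vertexai_streaming_response test; assuming Vertex AI credentials and project settings are configured the same way the tests do, the call pattern looks roughly like this:

    import asyncio
    import litellm

    async def main():
        # acompletion(..., stream=True) now routes through async_streaming() and
        # returns a CustomStreamWrapper that is consumed with `async for`.
        response = await litellm.acompletion(
            model="gemini-pro",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
            temperature=0.7,
            stream=True,
        )
        async for chunk in response:
            # delta.content may be empty on some chunks, so fall back to "".
            print(chunk.choices[0].delta.content or "", end="")

    asyncio.run(main())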