diff --git a/litellm/__init__.py b/litellm/__init__.py
index b7aeeb210..197c8c2e9 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -107,6 +107,7 @@ open_ai_text_completion_models: List = []
 cohere_models: List = []
 anthropic_models: List = []
 openrouter_models: List = []
+vertex_language_models: List = []
 vertex_chat_models: List = []
 vertex_code_chat_models: List = []
 vertex_text_models: List = []
@@ -133,6 +134,8 @@ for key, value in model_cost.items():
         vertex_text_models.append(key)
     elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
         vertex_code_text_models.append(key)
+    elif value.get('litellm_provider') == 'vertex_ai-language-models':
+        vertex_language_models.append(key)
     elif value.get('litellm_provider') == 'vertex_ai-chat-models':
         vertex_chat_models.append(key)
     elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 1ee9f434a..3d814f22e 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -77,6 +77,8 @@ def completion(
     try:
         from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
         from vertexai.language_models import TextGenerationModel, CodeGenerationModel
+        from vertexai.preview.generative_models import GenerativeModel, Part
+
         vertexai.init(
             project=vertex_project, location=vertex_location
         )
@@ -95,7 +97,12 @@
 
         mode = ""
         request_str = ""
-        if model in litellm.vertex_chat_models or ("chat" in model): # to catch chat-bison@003 or chat-bison@004 when google will release it
+        response_obj = None
+        if model in litellm.vertex_language_models:
+            chat_model = GenerativeModel(model)
+            mode = ""
+            request_str += f"chat_model = GenerativeModel({model})\n"
+        elif model in litellm.vertex_chat_models:
             chat_model = ChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"chat_model = ChatModel.from_pretrained({model})\n"
@@ -112,7 +119,24 @@
             mode = "chat"
             request_str += f"chat_model = CodeChatModel.from_pretrained({model})\n"
 
-        if mode == "chat":
+        if mode == "":
+            chat = chat_model.start_chat()
+            request_str += f"chat = chat_model.start_chat()\n"
+
+            if "stream" in optional_params and optional_params["stream"] == True:
+                request_str += f"chat.send_message({prompt}, **{optional_params})\n"
+                ## LOGGING
+                logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+                model_response = chat.send_message(prompt, **optional_params)
+                optional_params["stream"] = True
+                return model_response
+
+            request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = chat.send_message(prompt, **optional_params)
+            completion_response = response_obj.text
+            response_obj = response_obj._raw_response
+        elif mode == "chat":
             chat = chat_model.start_chat()
             request_str+= f"chat = chat_model.start_chat()\n"
 
@@ -161,17 +185,23 @@
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
-        prompt_tokens = len(
-            encoding.encode(prompt)
-        )
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-        )
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens
+        if model in litellm.vertex_language_models and response_obj is not None:
+            model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
+            usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
+                          completion_tokens=response_obj.usage_metadata.candidates_token_count,
+                          total_tokens=response_obj.usage_metadata.total_token_count)
+        else:
+            prompt_tokens = len(
+                encoding.encode(prompt)
+            )
+            completion_tokens = len(
+                encoding.encode(model_response["choices"][0]["message"].get("content", ""))
         )
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens
+            )
         model_response.usage = usage
         return model_response
     except Exception as e:
diff --git a/litellm/main.py b/litellm/main.py
index 6204f0b60..5138607bd 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -390,7 +390,6 @@ def completion(
             model=deployment_id
             custom_llm_provider="azure"
     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
-
     ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
     if input_cost_per_token is not None and output_cost_per_token is not None:
         litellm.register_model({
@@ -1136,7 +1135,7 @@
             )
             return response
         response = model_response
-    elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models or model in litellm.vertex_text_models or model in litellm.vertex_code_text_models or custom_llm_provider == "vertex_ai":
+    elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = (litellm.vertex_project
                              or get_secret("VERTEXAI_PROJECT"))
         vertex_ai_location = (litellm.vertex_location
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 8b6821caa..02531abe8 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -17,7 +17,7 @@ import json
 import os
 import tempfile
 
-litellm.num_retries = 3
+# litellm.num_retries = 3
 litellm.cache = None
 user_message = "Write a short poem about the sky"
 messages = [{"content": user_message, "role": "user"}]
@@ -73,6 +73,7 @@
     litellm.vertex_project = "hardy-device-386718"
 
     test_models = random.sample(test_models, 4)
+    test_models += litellm.vertex_language_models # always test gemini-pro
     for model in test_models:
         try:
             if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
@@ -86,7 +87,7 @@
             assert len(response.choices[0].message.content) > 1
         except Exception as e:
             pytest.fail(f"Error occurred: {e}")
-# test_vertex_ai()
+test_vertex_ai()
 
 def test_vertex_ai_stream():
     load_vertex_ai_credentials()
@@ -94,8 +95,9 @@
     litellm.vertex_project = "hardy-device-386718"
     import random
 
-    test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
+    test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
     test_models = random.sample(test_models, 4)
+    test_models += litellm.vertex_language_models # always test gemini-pro
     for model in test_models:
         try:
             if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
diff --git a/litellm/utils.py b/litellm/utils.py
index a8f469e5c..f5aa0d15f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -126,7 +126,7 @@ def map_finish_reason(finish_reason: str): # openai supports 5 stop sequences -
     # cohere mapping - https://docs.cohere.com/reference/generate
     elif finish_reason == "COMPLETE":
         return "stop"
-    elif finish_reason == "MAX_TOKENS":
+    elif finish_reason == "MAX_TOKENS": # cohere + vertex ai
         return "length"
     elif finish_reason == "ERROR_TOXIC":
         return "content_filter"
@@ -135,6 +135,10 @@ def map_finish_reason(finish_reason: str): # openai supports 5 stop sequences -
     # huggingface mapping https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate_stream
     elif finish_reason == "eos_token" or finish_reason == "stop_sequence":
         return "stop"
+    elif finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP": # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',]
+        return "stop"
+    elif finish_reason == "SAFETY": # vertex ai
+        return "content_filter"
     return finish_reason
 
 class FunctionCall(OpenAIObject):
@@ -2761,12 +2765,13 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
     ## openrouter
     elif model in litellm.maritalk_models:
         custom_llm_provider = "maritalk"
-    ## vertex - text + chat models
+    ## vertex - text + chat + language (gemini) models
     elif(
         model in litellm.vertex_chat_models or
         model in litellm.vertex_code_chat_models or
         model in litellm.vertex_text_models or
-        model in litellm.vertex_code_text_models
+        model in litellm.vertex_code_text_models or
+        model in litellm.vertex_language_models
     ):
         custom_llm_provider = "vertex_ai"
     ## ai21
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index c66b5c9d8..1d0ca5038 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -325,6 +325,14 @@
         "litellm_provider": "vertex_ai-code-chat-models",
         "mode": "chat"
     },
+    "gemini-pro": {
+        "max_tokens": 30720,
+        "max_output_tokens": 2048,
+        "input_cost_per_token": 0.0000000625,
+        "output_cost_per_token": 0.000000125,
+        "litellm_provider": "vertex_ai-language-models",
+        "mode": "chat"
+    },
     "palm/chat-bison": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000000125,