From b0a60e5c9138b8f742e2d5983f1c037df69b8fa0 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Thu, 10 Aug 2023 18:37:02 -0700
Subject: [PATCH] add vertex AI

---
 litellm/__init__.py          | 10 +++++++++-
 litellm/main.py              | 25 +++++++++++++++++++++++++
 litellm/tests/test_vertex.py | 32 ++++++++++++++++++++++++++++++++
 litellm/utils.py             |  7 +++++++
 4 files changed, 73 insertions(+), 1 deletion(-)
 create mode 100644 litellm/tests/test_vertex.py

diff --git a/litellm/__init__.py b/litellm/__init__.py
index ebe1bb722..5f526e0eb 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -11,6 +11,9 @@ anthropic_key = None
 replicate_key = None
 cohere_key = None
 openrouter_key = None
+vertex_project = None
+vertex_location = None
+
 hugging_api_token = None
 model_cost = {
     "gpt-3.5-turbo": {"max_tokens": 4000, "input_cost_per_token": 0.0000015, "output_cost_per_token": 0.000002},
@@ -96,7 +99,12 @@ openrouter_models = [
     'meta-llama/llama-2-70b-chat'
 ]
 
-model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models
+vertex_models = [
+    "chat-bison",
+    "chat-bison@001"
+]
+
+model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_models
 
 ####### EMBEDDING MODELS ###################
 open_ai_embedding_models = [
diff --git a/litellm/main.py b/litellm/main.py
index e24904424..901b80ece 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -397,7 +397,32 @@ def completion(
         "total_tokens": prompt_tokens + completion_tokens
       }
       response = model_response
+    elif model in litellm.vertex_models:
+      # import vertexai/if it fails then pip install vertexai# import cohere/if it fails then pip install cohere
+      install_and_import("vertexai")
+      import vertexai
+      from vertexai.preview.language_models import ChatModel, InputOutputTextPair
+      vertexai.init(project=litellm.vertex_project, location=litellm.vertex_location)
+      # vertexai does not use an API key, it looks for credentials.json in the environment
+      prompt = " ".join([message["content"] for message in messages])
+      ## LOGGING
+      logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+
+      chat_model = ChatModel.from_pretrained(model)
+
+
+      chat = chat_model.start_chat()
+      completion_response = chat.send_message(prompt, **optional_params)
+
+      ## LOGGING
+      logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+
+      ## RESPONSE OBJECT
+      model_response["choices"][0]["message"]["content"] = completion_response
+      model_response["created"] = time.time()
+      model_response["model"] = model
+      response = model_response
     else:
       ## LOGGING
       logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
diff --git a/litellm/tests/test_vertex.py b/litellm/tests/test_vertex.py
new file mode 100644
index 000000000..4214c300b
--- /dev/null
+++ b/litellm/tests/test_vertex.py
@@ -0,0 +1,32 @@
+# import sys, os
+# import traceback
+# from dotenv import load_dotenv
+# load_dotenv()
+# import os
+# sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+# import pytest
+# import litellm
+# from litellm import embedding, completion
+
+
+# litellm.vertex_project = "hardy-device-386718"
+# litellm.vertex_location = "us-central1"
+
+# user_message = "what's the weather in SF "
+# messages = [{ "content": user_message,"role": "user"}]
+
+# response = completion(model="chat-bison", messages=messages, temperature=0.5, top_p=0.1)
+# print(response)
+
+
+# # chat_model = ChatModel.from_pretrained("chat-bison@001")
+# # parameters = {
+# #     "temperature": 0.2,
+# #     "max_output_tokens": 256,
+# #     "top_p": 0.8,
+# #     "top_k": 40
+# # }
+
+# # chat = chat_model.start_chat()
+# # response = chat.send_message("who are u? write a sentence", **parameters)
+# # print(f"Response from Model: {response.text}")
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 85be206c5..e8f5f3976 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -264,6 +264,13 @@ def get_optional_params(
       optional_params["max_tokens"] = max_tokens
     if frequency_penalty != 0:
       optional_params["frequency_penalty"] = frequency_penalty
+  elif model == "chat-bison": # chat-bison has diff args from chat-bison@001 ty Google
+    if temperature != 1:
+      optional_params["temperature"] = temperature
+    if top_p != 1:
+      optional_params["top_p"] = top_p
+    if max_tokens != float('inf'):
+      optional_params["max_output_tokens"] = max_tokens
   else:# assume passing in params for openai/azure openai
     if functions != []: