add Vertex AI

ishaan-jaff 2023-08-10 18:37:02 -07:00
parent 24e51ec75c
commit b0a60e5c91
4 changed files with 73 additions and 1 deletion

View file

@@ -11,6 +11,9 @@ anthropic_key = None
replicate_key = None
cohere_key = None
openrouter_key = None
vertex_project = None
vertex_location = None
hugging_api_token = None
model_cost = {
"gpt-3.5-turbo": {"max_tokens": 4000, "input_cost_per_token": 0.0000015, "output_cost_per_token": 0.000002},
@@ -96,7 +99,12 @@ openrouter_models = [
'meta-llama/llama-2-70b-chat'
]
model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models
vertex_models = [
"chat-bison",
"chat-bison@001"
]
model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_models
####### EMBEDDING MODELS ###################
open_ai_embedding_models = [

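Taken together, the changes above add a module-level Vertex AI config surface. A minimal sketch of how a caller would use it (assumes this commit's litellm is importable; the project ID is hypothetical):

```python
import litellm

# Vertex AI is configured via module-level settings rather than an API key:
litellm.vertex_project = "my-gcp-project"   # hypothetical GCP project ID
litellm.vertex_location = "us-central1"     # region serving the PaLM chat models

# Both chat-bison variants are now part of the aggregate model list:
assert "chat-bison" in litellm.vertex_models
assert "chat-bison@001" in litellm.model_list
```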
View file

@@ -397,7 +397,32 @@ def completion(
"total_tokens": prompt_tokens + completion_tokens
}
response = model_response
elif model in litellm.vertex_models:
# import vertexai / if it fails then pip install vertexai
install_and_import("vertexai")
import vertexai
from vertexai.preview.language_models import ChatModel, InputOutputTextPair
vertexai.init(project=litellm.vertex_project, location=litellm.vertex_location)
# vertexai does not use an API key; it reads application-default credentials from the environment
prompt = " ".join([message["content"] for message in messages]) # flatten the chat history into a single prompt string
## LOGGING
logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
chat_model = ChatModel.from_pretrained(model)
chat = chat_model.start_chat()
completion_response = chat.send_message(prompt, **optional_params)
## LOGGING
logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = str(completion_response) # send_message returns a TextGenerationResponse, not a string
model_response["created"] = time.time()
model_response["model"] = model
response = model_response
else:
## LOGGING
logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)

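The new branch wraps the Vertex AI SDK's chat flow. Reduced to a standalone sketch (assumes `google-cloud-aiplatform` is installed and application-default credentials are configured; the project ID is hypothetical):

```python
import vertexai
from vertexai.preview.language_models import ChatModel

# init() reads application-default credentials; no API key is involved.
vertexai.init(project="my-gcp-project", location="us-central1")

chat_model = ChatModel.from_pretrained("chat-bison")
chat = chat_model.start_chat()
# send_message returns a TextGenerationResponse; .text holds the string.
response = chat.send_message("what's the weather in SF?", temperature=0.5)
print(response.text)
```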
View file

@@ -0,0 +1,32 @@
# import sys, os
# import traceback
# from dotenv import load_dotenv
# load_dotenv()
# import os
# sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
# import pytest
# import litellm
# from litellm import embedding, completion
# litellm.vertex_project = "hardy-device-386718"
# litellm.vertex_location = "us-central1"
# user_message = "what's the weather in SF "
# messages = [{ "content": user_message,"role": "user"}]
# response = completion(model="chat-bison", messages=messages, temperature=0.5, top_p=0.1)
# print(response)
# # chat_model = ChatModel.from_pretrained("chat-bison@001")
# # parameters = {
# # "temperature": 0.2,
# # "max_output_tokens": 256,
# # "top_p": 0.8,
# # "top_k": 40
# # }
# # chat = chat_model.start_chat()
# # response = chat.send_message("who are u? write a sentence", **parameters)
# # print(f"Response from Model: {response.text}")

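The test lands fully commented out, presumably because it needs live GCP credentials. A sketch of what enabling it would take (the key path and project ID are hypothetical):

```python
import os

# vertexai reads application-default credentials from the environment,
# not an API key, so point the auth library at a service-account key:
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/service-account.json"

import litellm
from litellm import completion

litellm.vertex_project = "my-gcp-project"   # hypothetical project ID
litellm.vertex_location = "us-central1"
print(completion(model="chat-bison", messages=[{"role": "user", "content": "hi"}]))
```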
View file

@@ -264,6 +264,13 @@ def get_optional_params(
optional_params["max_tokens"] = max_tokens
if frequency_penalty != 0:
optional_params["frequency_penalty"] = frequency_penalty
elif model == "chat-bison": # chat-bison takes different args than chat-bison@001
if temperature != 1:
optional_params["temperature"] = temperature
if top_p != 1:
optional_params["top_p"] = top_p
if max_tokens != float('inf'):
optional_params["max_output_tokens"] = max_tokens
else: # assume the params are for openai / azure openai
if functions != []:
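The point of the new branch is the rename: OpenAI's `max_tokens` becomes Vertex's `max_output_tokens`. The mapping in isolation (a standalone sketch of the logic above, not the actual helper):

```python
def vertex_chat_params(temperature=1.0, top_p=1.0, max_tokens=float("inf")):
    # Mirror of the chat-bison branch in get_optional_params: only
    # non-default values are forwarded, and max_tokens is renamed.
    optional_params = {}
    if temperature != 1:
        optional_params["temperature"] = temperature
    if top_p != 1:
        optional_params["top_p"] = top_p
    if max_tokens != float("inf"):
        optional_params["max_output_tokens"] = max_tokens  # Vertex's name
    return optional_params

print(vertex_chat_params(temperature=0.5, top_p=0.1, max_tokens=256))
# -> {'temperature': 0.5, 'top_p': 0.1, 'max_output_tokens': 256}
```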