mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
add vertex AI

parent 24e51ec75c
commit b0a60e5c91
4 changed files with 73 additions and 1 deletion
|
@ -11,6 +11,9 @@ anthropic_key = None
|
|||
replicate_key = None
|
||||
cohere_key = None
|
||||
openrouter_key = None
|
||||
vertex_project = None
|
||||
vertex_location = None
|
||||
|
||||
hugging_api_token = None
|
||||
model_cost = {
|
||||
"gpt-3.5-turbo": {"max_tokens": 4000, "input_cost_per_token": 0.0000015, "output_cost_per_token": 0.000002},
|
||||
|
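These new module-level settings are read by completion() when routing to Vertex AI; a minimal configuration sketch (the project ID and region below are placeholders, not values from this commit):

import litellm

# placeholder values for illustration; vertexai authenticates via GCP
# credentials in the environment rather than an API key (see the diff comment)
litellm.vertex_project = "my-gcp-project"
litellm.vertex_location = "us-central1"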
@@ -96,7 +99,12 @@ openrouter_models = [
     'meta-llama/llama-2-70b-chat'
 ]
 
-model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models
+vertex_models = [
+    "chat-bison",
+    "chat-bison@001"
+]
+
+model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_models
 
 ####### EMBEDDING MODELS ###################
 open_ai_embedding_models = [
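With vertex_models folded into model_list, the new models are discoverable the same way as every other provider's; for example:

import litellm

print("chat-bison" in litellm.vertex_models)   # True
print("chat-bison@001" in litellm.model_list)  # True, via the extended model_list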
@@ -397,7 +397,32 @@ def completion(
         "total_tokens": prompt_tokens + completion_tokens
       }
       response = model_response
+    elif model in litellm.vertex_models:
+      # import vertexai / if it fails then pip install vertexai
+      install_and_import("vertexai")
+      import vertexai
+      from vertexai.preview.language_models import ChatModel, InputOutputTextPair
+      vertexai.init(project=litellm.vertex_project, location=litellm.vertex_location)
+      # vertexai does not use an API key, it looks for credentials.json in the environment
+
+      prompt = " ".join([message["content"] for message in messages])
+      ## LOGGING
+      logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+
+      chat_model = ChatModel.from_pretrained(model)
+
+      chat = chat_model.start_chat()
+      completion_response = chat.send_message(prompt, **optional_params)
+
+      ## LOGGING
+      logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+
+      ## RESPONSE OBJECT
+      model_response["choices"][0]["message"]["content"] = completion_response
+      model_response["created"] = time.time()
+      model_response["model"] = model
+      response = model_response
     else:
       ## LOGGING
       logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
 
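install_and_import() is an existing litellm helper that this hunk reuses; its definition is not part of the diff. A minimal sketch of the pattern the comment describes, with the behavior assumed rather than taken from this commit:

import importlib
import subprocess
import sys

def install_and_import(package: str):
    # try the import first; on failure, pip-install the package, then import it
    try:
        importlib.import_module(package)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        importlib.import_module(package)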
litellm/tests/test_vertex.py (new file, 32 lines)

@@ -0,0 +1,32 @@
+# import sys, os
+# import traceback
+# from dotenv import load_dotenv
+# load_dotenv()
+# import os
+# sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+# import pytest
+# import litellm
+# from litellm import embedding, completion
+
+
+# litellm.vertex_project = "hardy-device-386718"
+# litellm.vertex_location = "us-central1"
+
+# user_message = "what's the weather in SF "
+# messages = [{ "content": user_message,"role": "user"}]
+
+# response = completion(model="chat-bison", messages=messages, temperature=0.5, top_p=0.1)
+# print(response)
+
+
+# # chat_model = ChatModel.from_pretrained("chat-bison@001")
+# # parameters = {
+# #     "temperature": 0.2,
+# #     "max_output_tokens": 256,
+# #     "top_p": 0.8,
+# #     "top_k": 40
+# # }
+
+# # chat = chat_model.start_chat()
+# # response = chat.send_message("who are u? write a sentence", **parameters)
+# # print(f"Response from Model: {response.text}")
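The whole test ships commented out, presumably because it needs a real GCP project and application-default credentials. Uncommented and wrapped for pytest, it might look like the sketch below (the project ID is the one in the commented code; the assertion shape is an assumption based on the response object built in completion()):

import litellm
from litellm import completion

litellm.vertex_project = "hardy-device-386718"
litellm.vertex_location = "us-central1"

def test_chat_bison_completion():
    messages = [{"content": "what's the weather in SF ", "role": "user"}]
    response = completion(model="chat-bison", messages=messages, temperature=0.5, top_p=0.1)
    # completion() fills choices[0].message.content with the Vertex AI reply
    assert response["choices"][0]["message"]["content"] is not None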
@@ -264,6 +264,13 @@ def get_optional_params(
       optional_params["max_tokens"] = max_tokens
     if frequency_penalty != 0:
       optional_params["frequency_penalty"] = frequency_penalty
+  elif model == "chat-bison": # chat-bison has diff args from chat-bison@001 ty Google
+    if temperature != 1:
+      optional_params["temperature"] = temperature
+    if top_p != 1:
+      optional_params["top_p"] = top_p
+    if max_tokens != float('inf'):
+      optional_params["max_output_tokens"] = max_tokens
 
   else: # assume passing in params for openai/azure openai
     if functions != []:
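The practical effect is that callers keep using OpenAI-style argument names and get_optional_params() translates max_tokens into Vertex AI's max_output_tokens for chat-bison; a usage sketch with illustrative values:

from litellm import completion

# temperature and top_p pass through unchanged; max_tokens is forwarded
# to Vertex AI as max_output_tokens by get_optional_params()
response = completion(
    model="chat-bison",
    messages=[{"role": "user", "content": "Say hi in one sentence."}],
    temperature=0.2,
    top_p=0.8,
    max_tokens=256,
)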